diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a897f95..424966a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,19 +2,46 @@ # CUTLASS 4.x -## [4.2.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-08-21) +## [4.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0) (2025-09-15) ### CuTe DSL -* We will likely be skipping 4.2.dev release and directly target 4.2. -* CuTeDSL version remains at 4.1.0 till then. +* More Python versions are now supported for both x86-64 and aarch64, including + - Python 3.10, 3.11, 3.12, and 3.13 +* Added new example and updated notebook to get started with CuTe DSL + - [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py) + - Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb) + + Added a section for introducing the broadcast +* API updates + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details +* Bug fixings and improvements + - Fixed ``cute.print_tensor`` for coordinate tensor + - Fixed `cute.print` for tuple of layouts + - Fixed frozen object is not properly updated after fully assigned in dynamic control flow + - Fixed assign tuple/list element in a dynamic control flow may cause compilation failure + - Improved error message when CUDA context is not initialized + - Improved docstring of congruent and weakly_congruent ### CUTLASS C++ * Add K major scale factor support for Hopper SM90 blockwise kernels. * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). - Add fused reduction kernel support for cutlass MLA. + - Add softmax skip correction. + - Support for GQA in FMHA backward kernel. - Fix an issue where `get_unmasked_trip_count` may return a negative value. - Fix an issue where mbarriers are initialized with a zero arrival count. + - Fix a corner case issue where the sequence length of q is not a multiple of tile_q. + - Remove tma padding for forward kernel inputs. * Add Blackwell SM120 blockwise gemm kernel example: [example 87](https://github.com/NVIDIA/cutlass/tree/main/87_blackwell_geforce_gemm_blockwise/). +* Add Blackwell SM100 kernel example of MoE gemm using TMA+CPASYNC to load input matrices: [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). +* Support for Blackwell SM103 kernels for B300 GPUs. + - Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp) + - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp). +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture: + - [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/). + - [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm). +* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM + - Unit test files with prefix name of `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/). * Support for Blackwell SM100 cpasync kernel. - Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp). - Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp). @@ -41,9 +68,24 @@ * Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics`. - Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md). * Rename legacy Python API package from `cutlass` to `cutlass_cppgen`. +* Add Blackwell EVT support to legacy Python interface. + - Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`. + - Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter. + - Added some support for running SM100 kernels via the Python interface. +* Instantiating more Blackwell kernels in profiler. + - Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations. + - To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. Profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels. + - Details please check [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md). * Fix some profiler issues: - Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line. - Fix some no output and timeout issues. + - Fix Pingpong Blockwise Hopper library generation. +* Fix some kernel issues: + - Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers. + - Support Blackwell SM120 mixed input blockscaled grouped GEMM. + - Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel. + - Fix an issue in [example 68](https://github.com/NVIDIA/cutlass/tree/main/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) where problem size has value of 0. + - Relax k dimension constraints for blockwise gemm on Hopper in [example 68](https://github.com/NVIDIA/cutlass/tree/main/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/). * Add following unit tests: - [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu) - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu) diff --git a/CMakeLists.txt b/CMakeLists.txt index 29fb4e21..23be6991 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -337,7 +337,7 @@ set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of opera set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.") set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.") set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).") -set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") +set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 and SM100 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") if(CUTLASS_LIBRARY_INSTANTIATION_LEVEL OR CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE) message(STATUS "Enable extended SM90 WGMMA instruction shapes for instantiation levels") diff --git a/README.md b/README.md index 06db5899..415eb3b1 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # CUTLASS 4.2.0 -_CUTLASS 4.2.0 - Aug 2025_ +_CUTLASS 4.2.0 - Sept 2025_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -46,16 +46,43 @@ To get started quickly - please refer : # What's New in CUTLASS 4.2 ## CuTe DSL -* We will likely be skipping 4.2.dev release and directly target 4.2. -* CuTeDSL version remains at 4.1.0 till then. +* More Python versions are now supported for both x86-64 and aarch64, including + - Python 3.10, 3.11, 3.12, and 3.13 +* Added new example and updated notebook to get started with CuTe DSL + - [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py) + - Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb) + + Added a section for introducing the broadcast +* API updates + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details +* Bug fixings and improvements + - Fixed ``cute.print_tensor`` for coordinate tensor + - Fixed `cute.print` for tuple of layouts + - Fixed frozen object is not properly updated after fully assigned in dynamic control flow + - Fixed assign tuple/list element in a dynamic control flow may cause compilation failure + - Improved error message when CUDA context is not initialized + - Improved docstring of congruent and weakly_congruent ## CUTLASS C++ * Add K major scale factor support for Hopper SM90 blockwise kernels. * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). - Add fused reduction kernel support for cutlass MLA. + - Add softmax skip correction. + - Support for GQA in FMHA backward kernel. - Fix an issue where `get_unmasked_trip_count` may return a negative value. - Fix an issue where mbarriers are initialized with a zero arrival count. + - Fix a corner case issue where the sequence length of q is not a multiple of tile_q. + - Remove tma padding for forward kernel inputs. * Add Blackwell SM120 blockwise gemm kernel example: [example 87](https://github.com/NVIDIA/cutlass/tree/main/87_blackwell_geforce_gemm_blockwise/). +* Add Blackwell SM100 kernel example of MoE gemm using TMA+CPASYNC to load input matrices: [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). +* Support for Blackwell SM103 kernels for B300 GPUs. + - Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp) + - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp). +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture: + - [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/). + - [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm). +* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM + - Unit test files with prefix name of `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/). * Support for Blackwell SM100 cpasync kernel. - Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp). - Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp). @@ -82,9 +109,24 @@ To get started quickly - please refer : * Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics`. - Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md). * Rename legacy Python API package from `cutlass` to `cutlass_cppgen`. +* Add Blackwell EVT support to legacy Python interface. + - Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`. + - Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter. + - Added some support for running SM100 kernels via the Python interface. +* Instantiating more Blackwell kernels in profiler. + - Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations. + - To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. Profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels. + - Details please check [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md). * Fix some profiler issues: - Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line. - Fix some no output and timeout issues. + - Fix Pingpong Blockwise Hopper library generation. +* Fix some kernel issues: + - Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers. + - Support Blackwell SM120 mixed input blockscaled grouped GEMM. + - Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel. + - Fix an issue in [example 68](https://github.com/NVIDIA/cutlass/tree/main/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) where problem size has value of 0. + - Relax k dimension constraints for blockwise gemm on Hopper in [example 68](https://github.com/NVIDIA/cutlass/tree/main/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/). * Add following unit tests: - [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu) - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu) diff --git a/examples/65_distributed_gemm/CMakeLists.txt b/examples/65_distributed_gemm/CMakeLists.txt index 247b3407..043f0cb8 100644 --- a/examples/65_distributed_gemm/CMakeLists.txt +++ b/examples/65_distributed_gemm/CMakeLists.txt @@ -26,7 +26,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +if (CUTLASS_NVCC_ARCHS MATCHES 90a) cutlass_example_add_executable( 65_distributed_gemm 65_distributed_gemm.cu ) +endif() diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu index 9e55755b..f080b6c6 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -129,7 +129,7 @@ using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_confi using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand -using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index ab88f54d..19e012b0 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -137,7 +137,7 @@ using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; @@ -402,12 +402,37 @@ void initialize(const OptionType &options) { beta_host.clear(); for (int i = 0; i < options.groups; i++) { - ptr_A_host.at(i) = block_A.get() + offset_A.at(i); - ptr_B_host.at(i) = block_B.get() + offset_B.at(i); - ptr_C_host.at(i) = block_C.get() + offset_C.at(i); - ptr_D_host.at(i) = block_D.get() + offset_D.at(i); - ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i); - ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i); + // If the current group's matrix has size 0, set the pointer to nullptr + if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) { + ptr_A_host.at(i) = nullptr; + } else { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + } + if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) { + ptr_B_host.at(i) = nullptr; + } else { + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + } + if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) { + ptr_C_host.at(i) = nullptr; + } else { + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + } + if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) { + ptr_D_host.at(i) = nullptr; + } else { + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) { + ptr_blockscale_A_host.at(i) = nullptr; + } else { + ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i); + } + if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) { + ptr_blockscale_B_host.at(i) = nullptr; + } else { + ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i); + } alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); ptr_alpha_host.at(i) = block_alpha.get() + i; @@ -546,10 +571,10 @@ bool verify(const OptionType &options) { blockscale_block_B.copy_to_host(blockscale_block_B_host.data()); bool passed = true; + std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl; for (int group_idx = 0; group_idx < options.groups; group_idx++) { // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape auto [m, n, k] = options.problem_sizes_host.at(group_idx); - auto gemm_problem_shape = cute::make_shape(m, n, k); // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx), @@ -598,11 +623,7 @@ bool verify(const OptionType &options) { ElementAccumulator, ElementCompute, decltype(C), - decltype(D), - unused_t, // bias - unused_t, // Aux - unused_t, // valpha - unused_t // vbeta + decltype(D) > epilogue_params; epilogue_params.C = C; @@ -639,6 +660,24 @@ int run(OptionType &options, bool host_problem_shapes_available = true) allocate(options); initialize(options); + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; + std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; + std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; + std::string raster = "Heuristic"; + if (options.raster_order == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster_order == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + // Instantiate CUTLASS kernel depending on templates Gemm gemm; @@ -671,8 +710,7 @@ int run(OptionType &options, bool host_problem_shapes_available = true) } // Run profiling loop - if (options.iterations > 0) - { + if (options.iterations > 0) { GpuTimer timer; timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { @@ -686,25 +724,6 @@ int run(OptionType &options, bool host_problem_shapes_available = true) result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); - std::string raster = "Heuristic"; - - if (options.raster_order == RasterOrderOptions::AlongN) { - raster = "Along N"; - } - else if (options.raster_order == RasterOrderOptions::AlongM) { - raster = "Along M"; - } - - std::cout << " Problem Sizes, Alpha, Beta " << std::endl; - for (int32_t i = 0; i < options.groups; ++i) { - std::cout << " " << options.problem_sizes_host.at(i); - std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; - } - std::cout << " Groups : " << options.groups << std::endl; - std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; - std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; - std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; - std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; std::cout << " GFLOPS: " << result.gflops << std::endl; fflush(stdout); diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu index b187d2da..b5419fe2 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu @@ -132,8 +132,7 @@ using ElementCompute = float; // E using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - -using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster static constexpr int ScaleGranularityM = 1; @@ -148,7 +147,7 @@ using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout ty using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using FusionOperation = cutlass::epilogue::fusion::LinearCombination; @@ -407,12 +406,37 @@ void initialize(const OptionType &options) { beta_host.clear(); for (int i = 0; i < options.groups; i++) { - ptr_A_host.at(i) = block_A.get() + offset_A.at(i); - ptr_B_host.at(i) = block_B.get() + offset_B.at(i); - ptr_C_host.at(i) = block_C.get() + offset_C.at(i); - ptr_D_host.at(i) = block_D.get() + offset_D.at(i); - ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i); - ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i); + // If the current group's matrix has size 0, set the pointer to nullptr + if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) { + ptr_A_host.at(i) = nullptr; + } else { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + } + if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) { + ptr_B_host.at(i) = nullptr; + } else { + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + } + if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) { + ptr_C_host.at(i) = nullptr; + } else { + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + } + if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) { + ptr_D_host.at(i) = nullptr; + } else { + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) { + ptr_blockscale_A_host.at(i) = nullptr; + } else { + ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i); + } + if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) { + ptr_blockscale_B_host.at(i) = nullptr; + } else { + ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i); + } alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); ptr_alpha_host.at(i) = block_alpha.get() + i; @@ -551,10 +575,10 @@ bool verify(const OptionType &options) { blockscale_block_B.copy_to_host(blockscale_block_B_host.data()); bool passed = true; + std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl; for (int group_idx = 0; group_idx < options.groups; group_idx++) { // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx); - auto gemm_problem_shape = cute::make_shape(m, n, k); // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx), @@ -637,10 +661,27 @@ bool verify(const OptionType &options) { template int run(OptionType &options, bool host_problem_shapes_available = true) { - allocate(options); initialize(options); + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; + std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; + std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; + std::string raster = "Heuristic"; + if (options.raster_order == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster_order == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + // Instantiate CUTLASS kernel depending on templates Gemm gemm; @@ -695,27 +736,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true) ScaleMsPerTile, ScaleNsPerTile>(result.avg_runtime_ms / 1000.0); - std::string raster = "Heuristic"; - - if (options.raster_order == RasterOrderOptions::AlongN) { - raster = "Along N"; - } - else if (options.raster_order == RasterOrderOptions::AlongM) { - raster = "Along M"; - } - - std::cout << " Problem Sizes, Alpha, Beta " << std::endl; - for (int32_t i = 0; i < options.groups; ++i) { - std::cout << " " << options.problem_sizes_host.at(i); - std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; - } - std::cout << " Groups : " << options.groups << std::endl; - std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; - std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; - std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; - std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; std::cout << " GFLOPS: " << result.gflops << std::endl; + std::cout << " GBPS: " << result.gbps << std::endl; + fflush(stdout); } return 0; @@ -766,8 +790,8 @@ int main(int argc, char const **args) { // Evaluate CUTLASS kernels // + std::cout << "Running tests with host problem shapes:" << std::endl; run(options, true); - std::cout << "Running tests without host problem shapes:" << std::endl; run(options, false); diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt index 09d506de..92e62653 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt @@ -44,6 +44,9 @@ set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes +set(TEST_K_16B_ALIGNED --m=256 --n=512 --k=960 --groups=10 --iterations=0) +set(TEST_K_16B_ALIGNED_LARGE_GROUP --m=256 --n=512 --k=960 --groups=512 --iterations=0) + cutlass_example_add_executable( 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu @@ -58,6 +61,8 @@ cutlass_example_add_executable( TEST_FIXED_LARGE_GROUP TEST_SMALL TEST_SMALL_LARGE_GROUP + TEST_K_16B_ALIGNED + TEST_K_16B_ALIGNED_LARGE_GROUP ) # MSVC will fail to compile this example with the following error: diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp index dacbb324..ff446a91 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -111,14 +111,14 @@ struct Options { int m = cmd_line_m; int n = cmd_line_n; int k = cmd_line_k; - if (m < 1) { - m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1); + if (m < 0) { + m = m_alignment * (rand() % (64 * alignment / m_alignment)); } - if (n < 1) { - n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1); + if (n < 0) { + n = n_alignment * (rand() % (64 * alignment / n_alignment)); } - if (k < 1) { - k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1); + if (k < 0) { + k = k_alignment * (rand() % (32 * alignment / k_alignment)); } problem_sizes_after_alignment_host.push_back({m, n, k}); problem_sizes_host.push_back({m, n, k}); diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 9ed1cf0c..cc09e68a 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -419,16 +419,16 @@ struct FwdRunner { using ElementAccumulatorPV = float; using ElementOut = cutlass::half_t; - // Q K D (B H) + // Q K D ((H_R, H_K) B) using ProblemShapeRegular = cute::tuple, int>>; using ProblemShapeVarlen = cute::tuple, int>>; using ProblemShapeType = std::conditional_t; - using StrideQ = cute::tuple, int>>; // Q D (H_G H_R B) - using StrideK = cute::tuple, int>>; // K D (H_G H_R B) + using StrideQ = cute::tuple, int>>; // Q D ((H_R, H_K), B) + using StrideK = cute::tuple, int>>; // K D ((H_R, H_K), B) using StrideV = StrideK; using StrideO = StrideQ; - using StrideLSE = cute::tuple<_1, cute::tuple, int>>; // Q (H_G H_R B) + using StrideLSE = cute::tuple<_1, cute::tuple, int>>; // Q ((H_R, H_K), B) static constexpr bool kIsPersistent = find_option_t::value; using TileScheduler = std::conditional_t; @@ -611,8 +611,8 @@ struct FwdRunner { ProblemShapeType problem_size_for_launch; - get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; - get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv}; get<2>(problem_size_for_launch) = get<2>(problem_size); get<3>(problem_size_for_launch) = get<3>(problem_size); @@ -669,9 +669,9 @@ struct FwdRunner { } auto buffer_init_fn = [&](auto& buffer) { - buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); - buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); - buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); + buffer.block_Q.reset(size(shape_QO)); + buffer.block_K.reset(size(shape_KV)); + buffer.block_V.reset(size(shape_KV)); buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); buffer.block_LSE.reset(size(shape_LSE)); buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); diff --git a/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu b/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu index 2b587513..2de135a0 100644 --- a/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu +++ b/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu @@ -590,8 +590,8 @@ struct MlaFwdRunner { ProblemShapeType problem_size_for_launch; - get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; - get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv}; get<2>(problem_size_for_launch) = get<2>(problem_size); get<3>(problem_size_for_launch) = get<3>(problem_size); @@ -651,9 +651,9 @@ struct MlaFwdRunner { } auto buffer_init_fn = [&](auto& buffer) { - buffer.block_Q.reset(size(shape_Q), kIsVarlen ? D_latent_rope*SQ*H : 0); - buffer.block_K.reset(size(shape_K), kIsVarlen ? D_latent_rope*SK*H_K : 0); - buffer.block_V.reset(size(shape_V), kIsVarlen ? D*SK*H_K : 0); + buffer.block_Q.reset(size(shape_Q)); + buffer.block_K.reset(size(shape_K)); + buffer.block_V.reset(size(shape_V)); buffer.block_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0); buffer.block_LSE.reset(size(shape_LSE)); buffer.block_ref_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0); @@ -849,7 +849,8 @@ struct MlaFwdRunner { flops *= static_cast(size<3,1>(problem_shape)); } - flops *= 2.0 * (std::is_same_v> ? 0.5 : 1.0); + flops *= 2.0 * (std::is_same_v> || + std::is_same_v> ? 0.5 : 1.0); flops *= static_cast(size<3,0>(problem_shape)); double flops0 = flops * static_cast(size<2, 0>(problem_shape) + size<2, 1>(problem_shape)); diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index c848fcfa..ac69081d 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -65,6 +65,8 @@ set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128) set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257) set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25) +set(TEST_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024) +set(TEST_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035) @@ -89,6 +91,8 @@ set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128) set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257) set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25) +set(TEST_MLA_FWD_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024) +set(TEST_MLA_FWD_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035) set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify) @@ -140,6 +144,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_VARLEN_18 TEST_VARLEN_19 TEST_VARLEN_20 + TEST_VARLEN_21 + TEST_VARLEN_22 ) target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO}) @@ -163,7 +169,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla.cu TEST_COMMAND_OPTIONS TEST_MLA_BASIC - TEST_MLA_SEP_REDUCTION + TEST_MLA_SEP_REDUCTION TEST_MLA_FUSE_REDUCTION ) target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -175,8 +181,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla.cu TEST_COMMAND_OPTIONS TEST_MLA_BASIC - TEST_MLA_SEP_REDUCTION - TEST_MLA_FUSE_REDUCTION + TEST_MLA_SEP_REDUCTION + TEST_MLA_FUSE_REDUCTION ) target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC) @@ -241,6 +247,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_MLA_FWD_VARLEN_18 TEST_MLA_FWD_VARLEN_19 TEST_MLA_FWD_VARLEN_20 + TEST_MLA_FWD_VARLEN_21 + TEST_MLA_FWD_VARLEN_22 ) target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO}) diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md index 1e28929b..581515f0 100644 --- a/examples/77_blackwell_fmha/README.md +++ b/examples/77_blackwell_fmha/README.md @@ -8,7 +8,7 @@ For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns. -For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively. +For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`. The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel. The kernel and collective layer are then formulated to be fmha-specific. @@ -67,6 +67,8 @@ For detailed information on how to invoke them, check out either the tests in `C to simplify the sample, clarified that `fmha_gen` sample only supports head dim 128. +* 4.3.0: For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`. + # Copyright Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index be090e10..dbdff428 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -225,8 +225,8 @@ struct CausalMask : NoMask { if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); - return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ; + return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count); } } diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index f2cbd4d1..afef0224 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -534,14 +534,14 @@ struct Sm100FmhaGenMainloopWarpspecialized { tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); - auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int{} * Int{}; + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); // Each thread owns a single row - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem - using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_LOAD = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp index 1951056b..3606dcc7 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized { auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; - auto problem_shape_qk = problem_shape; + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; - if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dQ) = get<0>(dQ); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_Q -= max_length_q * get<0>(dQ); - } - } - - if constexpr (is_variable_length_v>) { - auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; - if (cumulative_length_kv != nullptr) { - int max_length_kv = get<1>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dK) = get<0>(dK); - get<2,1>(dV) = get<0>(dV); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_K -= max_length_kv * get<0>(dK); - ptr_V -= max_length_kv * get<0>(dV); + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); } + } else { + problem_shape_qk = problem_shape; } auto params_qk = CollectiveMmaQK::to_underlying_arguments( @@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); int q_offs_0 = 0; - int q_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(params_problem_shape).max_length; - q_offs_0 = max_length_q - get<0>(problem_shape); - q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; get<2,1>(blk_coord_q) = 0; } } - Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); @@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); int kv_offs_0 = 0; - int kv_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length = get<1>(params_problem_shape).cumulative_length; if (cumulative_length != nullptr) { - int max_length = get<1>(params_problem_shape).max_length; - kv_offs_0 = max_length - get<1>(problem_shape); - kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; get<2,1>(blk_coord_kv) = 0; } } - Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); @@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized { ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); - Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp index c2d3e2ba..c8fc13b9 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; - auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; - if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dQ) = get<0>(dQ); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_Q -= max_length_q * get<0>(dQ); - } - } - - if constexpr (is_variable_length_v>) { - auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; - if (cumulative_length_kv != nullptr) { - int max_length_kv = get<1>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dK) = get<0>(dK); - get<2,1>(dV) = get<0>(dV); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_K -= max_length_kv * get<0>(dK); - ptr_V -= max_length_kv * get<0>(dV); + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); } + } else { + problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; } auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); @@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); int q_offs_0 = 0; - int q_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(params_problem_shape).max_length; - q_offs_0 = max_length_q - get<0>(problem_shape); - q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; get<2,1>(blk_coord_q) = 0; } } - Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); @@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); int kv_offs_0 = 0; - int kv_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length = get<1>(params_problem_shape).cumulative_length; if (cumulative_length != nullptr) { - int max_length = get<1>(params_problem_shape).max_length; - kv_offs_0 = max_length - get<1>(problem_shape); - kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; get<2,1>(blk_coord_kv) = 0; } } - Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); @@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); - Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 4cc42dc4..742b507d 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -1215,6 +1215,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { Tensor tTR_cST_p = thread_t2r.partition_D(cST); Tensor tTR_cST = split_wg(tTR_cST_p); Tensor tTR_rST = make_tensor(shape(tTR_cST)); + // Tensor tTR_tST_p = thread_t2r.partition_S(tSTtST); Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); @@ -1507,6 +1508,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); @@ -1835,6 +1839,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { /* no-op */ } +#endif } static dim3 get_block_shape() { diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp index 976e1f26..bf72843a 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -1480,6 +1480,9 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); @@ -1804,6 +1807,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { /* no-op */ } +#endif } static dim3 get_block_shape() { diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index 8fe503b4..a88b9a87 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -251,6 +251,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else TileScheduler tile_scheduler{params.tile_scheduler}; @@ -465,6 +468,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { else if (role == WarpRole::Correction) { cutlass::arch::warpgroup_reg_dealloc(); + bool has_valid = false; + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); @@ -476,6 +481,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + has_valid = true; + if (get<1>(logical_problem_shape) == 0) { mainloop.correction_empty( blk_coord, @@ -505,16 +512,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { if constexpr (NumWarpsEpilogue == 0) { static_assert(NumWarpsCorrection == 1); - uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + if (has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } } else if (role == WarpRole::MMA) { warpgroup_reg_set(); - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); - __syncwarp(); + bool allocated = false; CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { @@ -527,6 +535,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + if (!allocated) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + allocated = true; + } + if (get<1>(logical_problem_shape) == 0) { continue; } @@ -580,6 +594,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { else if (role == WarpRole::Epilogue) { warpgroup_reg_set(); + bool has_valid = false; + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); @@ -591,6 +607,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + has_valid = true; + epilogue.store( blk_coord, logical_problem_shape, params.epilogue, params.problem_shape, @@ -602,8 +620,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { static_assert(NumWarpsEpilogue <= 1); if constexpr (NumWarpsEpilogue == 1) { - uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + if(has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } } @@ -612,6 +632,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { /* no-op, donate regs and exit */ } +#endif } }; diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp index 5fd8a53c..d59b20ff 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp @@ -247,6 +247,9 @@ struct Sm100FmhaGenKernelWarpspecialized { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else TileScheduler tile_scheduler{params.tile_scheduler}; @@ -569,6 +572,7 @@ struct Sm100FmhaGenKernelWarpspecialized { /* no-op, donate regs and exit */ } +#endif } }; diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index 5eb8e20b..e9edb90e 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -507,6 +507,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else TileScheduler tile_scheduler(params.tile_scheduler); @@ -814,6 +817,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); } +#endif } template diff --git a/examples/82_blackwell_distributed_gemm/CMakeLists.txt b/examples/82_blackwell_distributed_gemm/CMakeLists.txt index fa8fe9ad..30d2eb49 100644 --- a/examples/82_blackwell_distributed_gemm/CMakeLists.txt +++ b/examples/82_blackwell_distributed_gemm/CMakeLists.txt @@ -26,7 +26,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 82_blackwell_distributed_gemm 82_blackwell_distributed_gemm.cu ) +endif() diff --git a/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu b/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu index dd748452..0a23488e 100644 --- a/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu +++ b/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu @@ -331,11 +331,13 @@ bool verify(MixedDtypeOptions const& options) { // // Compute reference output // - + + constexpr int AlignmentBdq = 128 / cutlass::sizeof_bits::value; + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, MmaType, LayoutA, AlignmentA, - MmaType, LayoutB, AlignmentB, + MmaType, LayoutB, AlignmentBdq, ElementAccumulator, MmaTileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, diff --git a/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp b/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp index 528e83cb..eeee7fa2 100644 --- a/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp +++ b/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp @@ -137,6 +137,9 @@ struct FmhaKernelTma { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else TileScheduler tile_scheduler{params.tile_scheduler}; // Shared memory. @@ -216,6 +219,7 @@ struct FmhaKernelTma { result, typename CollectiveMainloop::TiledMmaPV{}, params.problem_size, params.epilogue, epi_load_pipeline, storage.epilogue); +#endif } }; diff --git a/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp b/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp index 1e760a3e..4b96d711 100644 --- a/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp +++ b/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp @@ -160,6 +160,9 @@ struct FmhaKernelTmaWarpSpecialized { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else enum class WarpGroupRole { Producer = 0, @@ -412,6 +415,7 @@ struct FmhaKernelTmaWarpSpecialized { if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.arrive(); } } +#endif } }; diff --git a/examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt b/examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt index 2c6bd869..70eb6325 100644 --- a/examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt +++ b/examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt @@ -26,6 +26,21 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +set(TEST_RANDOM --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --beta=2.0 --k=512 --groups=51 --iterations=0) + +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --beta=0.5 --groups=50 --iterations=0) # Small problem sizes + +set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes + set(TEST_RANDOM_SMALL_GROUP --groups=3 --iterations=1) # Random problem sizes set(TEST_EPILOGUE_SMALL_GROUP --alpha=1.5 --beta=2.0 --groups=3 --iterations=1) # Random problem sizes @@ -35,6 +50,15 @@ cutlass_example_add_executable( 90_sm103_fp4_ultra_grouped_gemm 90_sm103_fp4_ultra_grouped_gemm.cu TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF TEST_RANDOM_SMALL_GROUP TEST_EPILOGUE_SMALL_GROUP ) diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu new file mode 100644 index 00000000..aef683cb --- /dev/null +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu @@ -0,0 +1,701 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Example of Blackwell MoE-style grouped NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B. + + This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices. + In the decoding stage of Mixture of Experts (MoE) models, the number of tokens in different experts + can varies a lot, which requires frequently updates of TMA descriptors in TMA-based implementation. + This examples uses CPASYNC to load activation (B) matrix to avoid the overhead of updating TMA descriptors. + + Usage: + $ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped + --m=28672 --n=4 --k=4096 --l=8 --benchmark=benchmark.txt + +*/ + +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include "helper.h" + + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + bool verification; + + int m, n, k, l; + + int iterations; + + std::string benchmark_path; + + Options(): + help(false), + error(false), + verification(true), + m(2048), n(2048), k(2048), l(1), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("iterations", iterations, 10); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "92_blackwell_moe_gemm_fp4_grouped\n\n" + << " Blackwell MoE-style grouped NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --iterations= Set the number of profiling iterations to perform\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --no_verif Do not run verification kernels\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v || cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleRunner { + // Type of kernel schedule to generate + using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100; + // Type of epilogue schedule to generate + using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; + static constexpr bool FuseQuantization = false; + + using LayoutATag = cutlass::layout::RowMajor; + using LayoutBTag = cutlass::layout::ColumnMajor; + using LayoutCTag = cutlass::layout::ColumnMajor; + using LayoutDTag = cutlass::layout::ColumnMajor; + using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand + + using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands + using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands + + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + using ElementC = cutlass::half_t; + using ElementD = cute::conditional_t; + using ElementSFD = ElementSF; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + + + using ClusterShapeMNK = Shape<_1,_1,_1>; + using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage) + + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + static constexpr int OutputSFVectorSize = 16; + + // D = alpha * acc + beta * C + // With BlockScaleFactor generation. + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, LayoutSFDTag, + ElementC>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTileMNK, ClusterShapeMNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + EpilogueScheduleType, + cute::conditional_t< + FuseQuantization, + FusionOperation, + cutlass::epilogue::fusion::LinearCombination> + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + MmaTileMNK, ClusterShapeMNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType + >::CollectiveOp; + + using ProblemShapeGroup = cutlass::gemm::GroupProblemShape>; // per group + using ProblemShapeMax = Shape; // max + using ProblemShape = cutlass::gemm::MoEProblemShape; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + + using FusionOp = typename Gemm::EpilogueOutputOp; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig; + using LayoutSFD = typename SfdOutputCfg::LayoutSF; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + LayoutA layout_A; + LayoutSFA layout_SFA; + StrideB stride_B; + LayoutB layout_B; + LayoutSFB layout_SFB; + StrideC stride_C; + LayoutC layout_C; + StrideD stride_D; + LayoutD layout_D; + LayoutSFD layout_SFD; + uint64_t seed = 0; + + cutlass::HostTensor block_A; + cutlass::HostTensor block_SFA; + cutlass::HostTensor block_B; + cutlass::HostTensor block_SFB; + cutlass::HostTensor block_C; + cutlass::HostTensor block_D; + cutlass::HostTensor block_SFD; + cutlass::HostTensor block_reference_D; + cutlass::HostTensor block_reference_SFD; + cutlass::HostTensor block_Normconst; + + cutlass::DeviceAllocation problem_sizes; + + // + // Methods + // + + bool verify(ProblemShape const& problem_size, float alpha, float beta) { + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + // think about how to simplify the gemm3x interface. + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD); + + if constexpr (FuseQuantization) { + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementCompute, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D), // TensorD + decltype(tensor_SFD), // TensorSfD + cute::Int, + cutlass::reference::host::SfStrategy::SfDGen + > epilogue_params {alpha, beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + else { + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementCompute, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params {alpha, beta, tensor_C, tensor_D }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + + bool passed = true; + + // Comparison + block_D.sync_host(); + + auto [maxM, maxN, maxK, L] = problem_size.max_problem_shape; + for (int i = 0; i < problem_size.problem_shape.num_groups; i++) { + auto problem = problem_size.problem_shape.get_host_problem_shape(i); + auto [M, N, K] = problem; + + // assume all M == maxM + auto refD_view = block_reference_D.host_view().subview(cutlass::make_Coord(M * N), cutlass::make_Coord(i * maxN * maxM)); + auto D_view = block_D.host_view().subview(cutlass::make_Coord(M * N), cutlass::make_Coord(i * maxN * maxM)); + + passed &= cutlass::reference::host::TensorEquals(refD_view, D_view); + } + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(ProblemShape const& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size.max_problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + // For SFD tensor layout + using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + // printf("\nStrideC = "); + // print(StrideC{}); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, L}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, L}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, L}); + + // printf("\nstride_C = "); + // print(stride_C); + + layout_A = make_layout(make_shape(M, K, L), stride_A); + layout_B = make_layout(make_shape(N, K, L), stride_B); + layout_C = make_layout(make_shape(M, N, L), stride_C); + layout_D = make_layout(make_shape(M, N, L), stride_D); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L)); + layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, L)); + + // printf("\nlayout_A = "); + // print(layout_A); + // printf("\nlayout_B = "); + // print(layout_B); + // printf("\nlayout_C = "); + // print(layout_C); + + // printf("\nsize(layout_A)=%lld", (long long)size(layout_A)); + // printf("\n"); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_Normconst.reset(cutlass::make_Coord(1)); + + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_Normconst.at(cutlass::make_Coord(0)) = 2; + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_D.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + block_SFD.sync_device(); + block_Normconst.sync_device(); + } + + /// Load a benchmark + std::vector benchmark_problems(std::string const& benchmark_path) { + std::vector problem_sizes_host; + + std::ifstream file(benchmark_path); + if (!file.good()) { + return {}; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + extent.at(i) = std::atoi(tokens.at(i).c_str()); + } + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + + return problem_sizes_host; + } + + bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) { + auto problem_sizes_host = benchmark_problems(options.benchmark_path); + if (problem_sizes_host.empty()) { + return false; + } + + problem_sizes.reset(problem_sizes_host.size()); + problem_sizes.copy_from_host(problem_sizes_host.data()); + + ProblemShape problem_size; + problem_size.max_problem_shape = ProblemShapeMax{options.m, options.n, options.k, options.l}; + problem_size.problem_shape.num_groups = problem_sizes_host.size(); + problem_size.problem_shape.problem_shapes = problem_sizes.get(); + problem_size.problem_shape.host_problem_shapes = problem_sizes_host.data(); + + initialize(problem_size); + + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + problem_size, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + }, + hw_info + }; + + auto f = [&](auto blockscale) { + auto impl = [this](auto& arguments) { + arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data(); + arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data(); + }; + if constexpr (decltype(blockscale)::value) { + impl(arguments); + } + }; + f(std::bool_constant()); + + // arguments.scheduler.max_swizzle_size = options.swizzle; + + arguments.epilogue.thread.alpha = 1.0f; + arguments.epilogue.thread.beta = 0.0f; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + if (options.verification) { + // Verify that the result is correct + bool passed = verify(problem_size, 1.0f, 0.0f); + + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + exit(-1); + return false; + } + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_op.run()); + } + timer.stop(); + + // Compute average setup and runtime and FLOPs. + float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS : " << flops / 1e12 << std::endl; + } + + return true; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + ExampleRunner runner_mixed_tma_cpasync; + runner_mixed_tma_cpasync.run(options, hw_info); + +#endif + + return 0; +} diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular.cu new file mode 100644 index 00000000..e129d07e --- /dev/null +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular.cu @@ -0,0 +1,654 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Example of Blackwell MoE-style NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B + + This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices. + In the decoding stage of Mixture of Experts (MoE) models, the number of tokens in different experts + can varies a lot, which requires frequently updates of TMA descriptors in TMA-based implementation. + This examples uses CPASYNC to load activation (B) matrix to avoid the overhead of updating TMA descriptors. + + This example assumes all experts have the same number of tokens, in which case the GEMM becomes a regular (batched) gemm. + For the realistic use case where each expert may have different number of tokens (grouped GEMM), check 92_blackwell_moe_gemm_fp4_grouped. + + Usage: + $ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular + --m=28672 --n=4 --k=4096 --l=8 + +*/ + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include "helper.h" + + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + bool verification; + + int m, n, k, l; + + int iterations; + + Options(): + help(false), + error(false), + verification(true), + m(2048), n(2048), k(2048), l(1), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("iterations", iterations, 10); + + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "92_blackwell_moe_gemm_fp4_regular\n\n" + << " Blackwell NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --iterations= Set the number of profiling iterations to perform\n" + << " --no_verif Do not run verification kernels\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v || cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// MSVC complain about it if moved to ExampleRunner +static constexpr int OutputSFVectorSize = 16; +using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig; + +// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective +// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the +// number of pipeline stages. +template < + // Type of kernel schedule to generate + class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + // Type of epilogue schedule to generate + class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, + bool FuseQuantization = false +> +struct ExampleRunner { + + using LayoutATag = cutlass::layout::RowMajor; + using LayoutBTag = cutlass::layout::ColumnMajor; + using LayoutCTag = cutlass::layout::ColumnMajor; + using LayoutDTag = cutlass::layout::ColumnMajor; + using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand + + using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands + using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands + + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + using ElementC = cutlass::half_t; + using ElementD = cute::conditional_t; + using ElementSFD = ElementSF; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + + + using ClusterShapeMNK = Shape<_1,_1,_1>; + using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage) + + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + + // D = alpha * acc + beta * C + // With BlockScaleFactor generation. + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, LayoutSFDTag, + ElementC>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTileMNK, ClusterShapeMNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + EpilogueScheduleType, + cute::conditional_t< + FuseQuantization, + FusionOperation, + cutlass::epilogue::fusion::LinearCombination> + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + MmaTileMNK, ClusterShapeMNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + + using FusionOp = typename Gemm::EpilogueOutputOp; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + using LayoutSFD = typename SfdOutputCfg::LayoutSF; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + LayoutA layout_A; + LayoutSFA layout_SFA; + StrideB stride_B; + LayoutB layout_B; + LayoutSFB layout_SFB; + StrideC stride_C; + LayoutC layout_C; + StrideD stride_D; + LayoutD layout_D; + LayoutSFD layout_SFD; + uint64_t seed = 0; + + cutlass::HostTensor block_A; + cutlass::HostTensor block_SFA; + cutlass::HostTensor block_B; + cutlass::HostTensor block_SFB; + cutlass::HostTensor block_C; + cutlass::HostTensor block_D; + cutlass::HostTensor block_SFD; + cutlass::HostTensor block_reference_D; + cutlass::HostTensor block_reference_SFD; + cutlass::HostTensor block_Normconst; + + // + // Methods + // + + bool verify(ProblemShapeType const& problem_size, float alpha, float beta) { + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + // think about how to simplify the gemm3x interface. + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD); + + if constexpr (FuseQuantization) { + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementCompute, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D), // TensorD + decltype(tensor_SFD), // TensorSfD + cute::Int, + cutlass::reference::host::SfStrategy::SfDGen + > epilogue_params {alpha, beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + else { + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementCompute, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params {alpha, beta, tensor_C, tensor_D }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + + bool passed = true, passed_sfd = true; + + // Comparison + block_D.sync_host(); + passed &= cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + + if constexpr (FuseQuantization) { + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + block_SFD.sync_host(); + passed_sfd &= cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view()); + passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0); + passed_sfd &= (cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0); + } + + // printf("passed=%d\n", (int)passed); + // printf("passed_sfd=%d\n", (int)passed_sfd); + return passed && passed_sfd; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(ProblemShapeType const& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + // For SFD tensor layout + using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + // printf("\nStrideC = "); + // print(StrideC{}); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, L}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, L}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, L}); + + // printf("\nstride_C = "); + // print(stride_C); + + layout_A = make_layout(make_shape(M, K, L), stride_A); + layout_B = make_layout(make_shape(N, K, L), stride_B); + layout_C = make_layout(make_shape(M, N, L), stride_C); + layout_D = make_layout(make_shape(M, N, L), stride_D); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L)); + layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, L)); + + // printf("\nlayout_A = "); + // print(layout_A); + // printf("\nlayout_B = "); + // print(layout_B); + // printf("\nlayout_C = "); + // print(layout_C); + + // printf("\nsize(layout_A)=%lld", (long long)size(layout_A)); + // printf("\n"); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_Normconst.reset(cutlass::make_Coord(1)); + + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_Normconst.at(cutlass::make_Coord(0)) = 2; + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_D.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + block_SFD.sync_device(); + block_Normconst.sync_device(); + } + + bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + }, + hw_info + }; + + if constexpr (IsBlockScaleSupported) { + arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data(); + arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data(); + } + + // arguments.scheduler.max_swizzle_size = options.swizzle; + + arguments.epilogue.thread.alpha = 1.0f; + arguments.epilogue.thread.beta = 0.0f; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + if (options.verification) { + // Verify that the result is correct + bool passed = verify(problem_size, 1.0f, 0.0f); + + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + exit(-1); + return false; + } + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_op.run()); + } + timer.stop(); + + // Compute average setup and runtime and FLOPs. + float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS : " << flops / 1e12 << std::endl; + } + + return true; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + std::cout << "Running kernel with TMA load:" << std::endl; + ExampleRunner runner_tma; + runner_tma.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + ExampleRunner runner_mixed_tma_cpasync; + runner_mixed_tma_cpasync.run(options, hw_info); + +#endif + + return 0; +} diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu new file mode 100644 index 00000000..8e5325af --- /dev/null +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Example of Blackwell MoE-style grouped GEMM implementation using TMA to load A and CPASYNC to load B. + + This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices. + In the decoding stage of Mixture of Experts (MoE) models, the number of tokens in different experts + can varies a lot, which requires frequently updates of TMA descriptors in TMA-based implementation. + This examples uses CPASYNC to load activation (B) matrix to avoid the overhead of updating TMA descriptors. + + Usage: + $ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped + --m=28672 --n=4 --k=4096 --l=8 --benchmark=benchmark.txt + +*/ + +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + bool verification; + + int m, n, k, l; + + int iterations; + + std::string benchmark_path; + + Options(): + help(false), + error(false), + verification(true), + m(2048), n(2048), k(2048), l(1), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("iterations", iterations, 10); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "92_blackwell_moe_gemm_grouped\n\n" + << " Blackwell MoE-style grouped GEMM implementation using TMA to load A and CPASYNC to load B\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --iterations= Set the number of profiling iterations to perform\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --no_verif Do not run verification kernels\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } + else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } + else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleRunner { + + // Type of kernel schedule to generate + using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100; + // Type of epilogue schedule to generate + using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + using ClusterShapeMNK = Shape<_1,_1,_1>; + using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage) + + // 16B alignment lets us use TMA + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileMNK, ClusterShapeMNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileMNK, ClusterShapeMNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType + >::CollectiveOp; + + using ProblemShapeGroup = cutlass::gemm::GroupProblemShape>; // per group + using ProblemShapeMax = Shape; // max + using ProblemShape = cutlass::gemm::MoEProblemShape; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + //, cutlass::gemm::MoEScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t; + using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + cutlass::DeviceAllocation problem_sizes; + + + // + // Methods + // + + bool verify(ProblemShape const& problem_size, float alpha, float beta) { + auto [maxM, maxN, maxK, L] = problem_size.max_problem_shape; + for (int i = 0; i < problem_size.problem_shape.num_groups; i++) { + auto problem = problem_size.problem_shape.get_host_problem_shape(i); + auto [M, N, K] = problem; + + cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, Gemm::LayoutA(maxK)); + cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, Gemm::LayoutB(maxK)); + cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutC(maxM)); + cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutD(maxM)); + + using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementScalar, + ElementAccumulator>; + + DeviceGemmReference gemm_reference; + + gemm_reference( + {M, N, K}, + ElementScalar(alpha), + ref_A, + ref_B, + ElementScalar(beta), + ref_C, + ref_D); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + // assume all M == maxM + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + size_t(1) * i * maxN * maxM, block_D.get() + size_t(1) * i * maxN * maxM, M * N); + if (!passed) { + return false; + } + } + + return true; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(ProblemShape const& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size.max_problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(size_t(1) * M * K * L); + block_B.reset(size_t(1) * K * N * L); + block_C.reset(size_t(1) * M * N * L); + block_D.reset(size_t(1) * M * N * L); + block_ref_D.reset(size_t(1) * M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + /// Load a benchmark + std::vector benchmark_problems(std::string const& benchmark_path) { + std::vector problem_sizes_host; + + std::ifstream file(benchmark_path); + if (!file.good()) { + return {}; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + extent.at(i) = std::atoi(tokens.at(i).c_str()); + } + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + + return problem_sizes_host; + } + + bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) { + auto problem_sizes_host = benchmark_problems(options.benchmark_path); + if (problem_sizes_host.empty()) { + return false; + } + + problem_sizes.reset(problem_sizes_host.size()); + problem_sizes.copy_from_host(problem_sizes_host.data()); + + ProblemShape problem_size; + problem_size.max_problem_shape = ProblemShapeMax{options.m, options.n, options.k, options.l}; + problem_size.problem_shape.num_groups = problem_sizes_host.size(); + problem_size.problem_shape.problem_shapes = problem_sizes.get(); + problem_size.problem_shape.host_problem_shapes = problem_sizes_host.data(); + + initialize(problem_size); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{}, // epilogue.thread + block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + // arguments.scheduler.max_swizzle_size = options.swizzle; + + arguments.epilogue.thread.alpha = 1.0f; + arguments.epilogue.thread.beta = 0.0f; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + if (options.verification) { + // Verify that the result is correct + bool passed = verify(problem_size, 1.0f, 0.0f); + + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + exit(-1); + return false; + } + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_op.run()); + } + timer.stop(); + + // Compute average setup and runtime and FLOPs. + float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS : " << flops / 1e12 << std::endl; + } + + return true; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + ExampleRunner runner_mixed_tma_cpasync; + runner_mixed_tma_cpasync.run(options, hw_info); + +#endif + + return 0; +} diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu new file mode 100644 index 00000000..f33838c8 --- /dev/null +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu @@ -0,0 +1,484 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Example of Blackwell MoE-style GEMM implementation using TMA to load A and CPASYNC to load B. + + This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices. + In the decoding stage of Mixture of Experts (MoE) models, the number of tokens in different experts + can varies a lot, which requires frequently updates of TMA descriptors in TMA-based implementation. + This examples uses CPASYNC to load activation (B) matrix to avoid the overhead of updating TMA descriptors. + + Usage: + $ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular + --m=28672 --n=4 --k=4096 --l=8 + +*/ + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + bool verification; + + int m, n, k, l; + + int iterations; + + Options(): + help(false), + error(false), + verification(true), + m(2048), n(2048), k(2048), l(1), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("iterations", iterations, 10); + + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "92_blackwell_moe_gemm_regular\n\n" + << " Blackwell GEMM implementation using TMA to load A and CPASYNC to load B\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --iterations= Set the number of profiling iterations to perform\n" + << " --no_verif Do not run verification kernels\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } + else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } + else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective +// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the +// number of pipeline stages. +template < + // Type of kernel schedule to generate + class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, + // Type of epilogue schedule to generate + class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto +> +struct ExampleRunner { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + using ClusterShapeMNK = Shape<_1,_1,_1>; + using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage) + + // 16B alignment lets us use TMA + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileMNK, ClusterShapeMNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileMNK, ClusterShapeMNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t; + using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(ProblemShapeType const& problem_size, float alpha, float beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + ElementScalar(alpha), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + ElementScalar(beta), + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(ProblemShapeType const& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(size_t(1) * M * K * L); + block_B.reset(size_t(1) * K * N * L); + block_C.reset(size_t(1) * M * N * L); + block_D.reset(size_t(1) * M * N * L); + block_ref_D.reset(size_t(1) * M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{}, // epilogue.thread + block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + // arguments.scheduler.max_swizzle_size = options.swizzle; + + arguments.epilogue.thread.alpha = 1.0f; + arguments.epilogue.thread.beta = 0.0f; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + if (options.verification) { + // Verify that the result is correct + bool passed = verify(problem_size, 1.0f, 0.0f); + + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + exit(-1); + return false; + } + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_op.run()); + } + timer.stop(); + + // Compute average setup and runtime and FLOPs. + float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS : " << flops / 1e12 << std::endl; + } + + return true; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + std::cout << "Running kernel with TMA load:" << std::endl; + ExampleRunner runner_tma; + runner_tma.run(options, hw_info); + + std::cout << "Running kernel with CPASYNC load:" << std::endl; + ExampleRunner runner_cpasync; + runner_cpasync.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + ExampleRunner runner_mixed_tma_cpasync; + runner_mixed_tma_cpasync.run(options, hw_info); + +#endif + + return 0; +} diff --git a/examples/92_blackwell_moe_gemm/CMakeLists.txt b/examples/92_blackwell_moe_gemm/CMakeLists.txt new file mode 100644 index 00000000..c88a461e --- /dev/null +++ b/examples/92_blackwell_moe_gemm/CMakeLists.txt @@ -0,0 +1,70 @@ +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_MIXTRAL_A --m=28672 --n=4 --k=4096 --l=8) +set(TEST_MIXTRAL_B --m=4096 --n=4 --k=14336 --l=8) +set(TEST_DEEPSEEK_A --m=4096 --n=1 --k=7168 --l=256) +set(TEST_DEEPSEEK_B --m=7168 --n=1 --k=2048 --l=256) +set(TEST_IRREGULAR_MNK --m=4080 --n=9 --k=4112 --l=8) # M,N,K not multiples of tile size + +set(TEST_DEEPSEEK_A_FP4 --m=1024 --n=1 --k=7168 --l=256) # TP=1 shape is too large for PackedVectorLayout +set(TEST_DEEPSEEK_B_FP4 --m=7168 --n=1 --k=512 --l=256) +set(TEST_IRREGULAR_MNK_FP4 --m=4080 --n=9 --k=4160 --l=8) + +if (CUTLASS_NVCC_ARCHS MATCHES 100a) +cutlass_example_add_executable( + 92_blackwell_moe_gemm_regular + 92_blackwell_moe_gemm_regular.cu + TEST_COMMAND_OPTIONS + TEST_MIXTRAL_A + TEST_MIXTRAL_B + TEST_DEEPSEEK_A + TEST_DEEPSEEK_B + TEST_IRREGULAR_MNK +) + +cutlass_example_add_executable( + 92_blackwell_moe_gemm_grouped + 92_blackwell_moe_gemm_grouped.cu +) + +cutlass_example_add_executable( + 92_blackwell_moe_gemm_fp4_regular + 92_blackwell_moe_gemm_fp4_regular.cu + TEST_COMMAND_OPTIONS + TEST_MIXTRAL_A + TEST_MIXTRAL_B + TEST_DEEPSEEK_A_FP4 + TEST_DEEPSEEK_B_FP4 + TEST_IRREGULAR_MNK_FP4 +) +cutlass_example_add_executable( + 92_blackwell_moe_gemm_fp4_grouped + 92_blackwell_moe_gemm_fp4_grouped.cu +) +endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b46fbbda..6d97ea56 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -169,6 +169,7 @@ foreach(EXAMPLE 89_sm103_fp4_ultra_gemm 90_sm103_fp4_ultra_grouped_gemm 91_fp4_gemv + 92_blackwell_moe_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/examples/python/CuTeDSL/ampere/all_reduce.py b/examples/python/CuTeDSL/ampere/all_reduce.py new file mode 100644 index 00000000..5f7e20fe --- /dev/null +++ b/examples/python/CuTeDSL/ampere/all_reduce.py @@ -0,0 +1,314 @@ +import os +import torch +import argparse +from cuda import cuda +from cuda.bindings import driver +from typing import Type + +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +import torch.multiprocessing as mp + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +from cutlass._mlir.dialects import llvm, builtin, vector, arith + +WORLD_SIZE = 8 +PING_PONG_SIZE = 3 + + +def setup(rank, world_size): + # set environment variables for torch.distributed environment + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12959" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def cleanup(): + dist.destroy_process_group() + + +class AllReduceKernel: + + @cute.jit + def __call__( + self, + rank, + signal, + local_input: cute.Tensor, + local_output: cute.Tensor, + buffer0: cute.Tensor, + buffer1: cute.Tensor, + buffer2: cute.Tensor, + buffer3: cute.Tensor, + buffer4: cute.Tensor, + buffer5: cute.Tensor, + buffer6: cute.Tensor, + buffer7: cute.Tensor, + stream: cuda.CUstream, + ): + # define constants for future use + num_of_elements = cute.size(local_input.layout) + # 128 threads per block and 4 elements per thread + tv_layout = cute.make_layout(((128), (4)), stride=((1), (1))) + tile = cute.size(tv_layout.shape) + + buffers = [ + buffer0, + buffer1, + buffer2, + buffer3, + buffer4, + buffer5, + buffer6, + buffer7, + ] + tiled_buffers = [ + cute.logical_divide(buffer, (tile, None, None)) for buffer in buffers + ] + + tiled_input = cute.zipped_divide(local_input, cute.make_layout(tile)) + tiled_output = cute.zipped_divide(local_output, cute.make_layout(tile)) + self.kernel( + tiled_buffers[0], + tiled_buffers[1], + tiled_buffers[2], + tiled_buffers[3], + tiled_buffers[4], + tiled_buffers[5], + tiled_buffers[6], + tiled_buffers[7], + tiled_input, + tiled_output, + tv_layout, + cutlass.Int32(signal), + cutlass.Int32(rank), + ).launch( + grid=[num_of_elements // tile, 1, 1], + block=[tv_layout.shape[0], 1, 1], + stream=stream, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + buffer0: cute.Tensor, + buffer1: cute.Tensor, + buffer2: cute.Tensor, + buffer3: cute.Tensor, + buffer4: cute.Tensor, + buffer5: cute.Tensor, + buffer6: cute.Tensor, + buffer7: cute.Tensor, + local_input: cute.Tensor, + local_output: cute.Tensor, + tv_layout: cute.Layout, + signal: cutlass.Int32, + rank: cutlass.Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + ping = signal % 3 + pong = (signal + 1) % 3 + + buffers = [ + buffer0, + buffer1, + buffer2, + buffer3, + buffer4, + buffer5, + buffer6, + buffer7, + ] + + def get_buffer(): + t = buffers[2] + if rank == cutlass.Int32(0): + t = buffers[0] + if rank == cutlass.Int32(1): + t = buffers[1] + if rank == cutlass.Int32(2): + t = buffers[2] + if rank == cutlass.Int32(3): + t = buffers[3] + if rank == cutlass.Int32(4): + t = buffers[4] + if rank == cutlass.Int32(5): + t = buffers[5] + if rank == cutlass.Int32(6): + t = buffers[6] + if rank == cutlass.Int32(7): + t = buffers[7] + return t + + buffer_local = get_buffer() + cta_coord = (None, bidx) + local_tile_in = local_input[cta_coord] + local_tile_out = local_output[cta_coord] + + ping_coord = ((None, bidx), None, ping) + read_buffer = buffer_local[ping_coord] + + pong_coord = ((None, bidx), None, pong) + clear_buffer = buffer_local[pong_coord] + + write_coord = ((None, bidx), rank, ping) + write_buffers = [buffer[write_coord] for buffer in buffers] + + # assume all buffers have the same element type with input + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + buffer0.element_type, + num_bits_per_copy=64, + memory_scope=cute.nvgpu.common.MemoryScope.SYS, + memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + buffer0.element_type, + num_bits_per_copy=64, + memory_scope=cute.nvgpu.common.MemoryScope.SYS, + memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, + ) + tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1]) + thr_copy = tiled_copy.get_slice(tidx) + + thr_write_buffer_list = [ + thr_copy.partition_D(tensor) for tensor in write_buffers + ] + thr_read_buffer = thr_copy.partition_S(read_buffer) + thr_clear_buffer = thr_copy.partition_D(clear_buffer) + + thr_in = thr_copy.partition_S(local_tile_in) + thr_out = thr_copy.partition_D(local_tile_out) + + frg_in = cute.make_fragment_like(thr_in) + frg_clear = cute.make_fragment_like(thr_clear_buffer) + frg_acc = cute.make_fragment_like(thr_out) + frg_acc.fill(0.0) + + clear_tensor = frg_clear.load() + frg_size = cute.size(clear_tensor.shape) + neg0_i32_vec = cute.full_like(clear_tensor, 0x80000000, cutlass.Int32) + neg0_f32_vec = vector.bitcast(T.vector(frg_size, T.f32()), neg0_i32_vec) + neg0_f32_tensor = cute.TensorSSA( + neg0_f32_vec, clear_tensor.shape, cutlass.Float32 + ) + frg_clear.store(neg0_f32_tensor) + + cute.copy(copy_atom_load, thr_in, frg_in) + + for thr_write_buffer in thr_write_buffer_list: + cute.copy(copy_atom_store, frg_in, thr_write_buffer) + + cute.copy(copy_atom_store, frg_clear, thr_clear_buffer) + + frg_in_vector_neg0_i32 = cute.full_like( + frg_in, cutlass.Int32(0x80000000), cutlass.Int32 + ) + frg_in_size = cute.size(frg_in.shape) + + for i in range(WORLD_SIZE): + read_coord = (None, 0, i) + cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0]) + frg_vector = frg_in.load() + frg_vector_i32 = vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector) + isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32) + while not isNotNeg0: + cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0]) + frg_vector = frg_in.load() + frg_vector_i32 = vector.bitcast( + T.vector(frg_in_size, T.i32()), frg_vector + ) + isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32) + frg_acc.store(frg_in.load() + frg_acc.load()) + + cute.copy(copy_atom_stg, frg_acc, thr_out) + + +def run_all_reduce(rank, M, N, dtype: Type[cutlass.Numeric]): + setup(rank, WORLD_SIZE) + + input_tensor = torch.randn(M * N, device=f"cuda:{rank}") + output_tensor = torch.zeros(M * N, device=f"cuda:{rank}") + + # init tensors on different devices + t = symm_mem.empty( + [ + PING_PONG_SIZE, + WORLD_SIZE, + M * N, + ], + device="cuda", + ).neg_() + hdl = symm_mem.rendezvous(t, dist.group.WORLD) + buffer_tensor_list = [ + hdl.get_buffer(rank, t.shape, t.dtype).permute(2, 1, 0) + for rank in range(WORLD_SIZE) + ] + + # enable peer access + driver.cuInit(0) + dev_list = [driver.cuDeviceGet(i)[1] for i in range(WORLD_SIZE)] + ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list] + for i in range(WORLD_SIZE): + driver.cuCtxSetCurrent(ctx_list[i]) + for j in range(WORLD_SIZE): + if i == j: + continue + driver.cuCtxEnablePeerAccess(ctx_list[j], 0) + driver.cuCtxSetCurrent(ctx_list[rank]) + + stream = cutlass.cuda.default_stream() + all_reduce_kernel = AllReduceKernel() + dlpack_buffers = [from_dlpack(x, assumed_align=32) for x in buffer_tensor_list] + all_reduce_kernel( + rank, + 0, + from_dlpack(input_tensor, assumed_align=32), + from_dlpack(output_tensor, assumed_align=32), + *dlpack_buffers, + stream, + ) + torch.cuda.synchronize(0) + + # use torch api to get reference and inplace stored to input_tensor + ref_tensor = input_tensor.clone() + dist.all_reduce(ref_tensor, op=dist.ReduceOp.SUM) + + # check result of output tensor, allow small error due to different accumulator datatypes + equal_mask = (ref_tensor.cpu() - output_tensor.cpu()).abs() < 1e-4 + result = (equal_mask.sum()).item() == ref_tensor.numel() + + if result: + print(f"rank {rank} test passed") + else: + print(f"rank {rank} test failed") + print( + "ref_tensor[ref_tensor != output_tensor]: ", + ref_tensor[ref_tensor != output_tensor], + ) + print( + "output_tensor[ref_tensor != output_tensor]: ", + output_tensor[ref_tensor != output_tensor], + ) + + cleanup() + + +def main(): + M = 1024 + N = 1024 + + # each process will run run_all_reduce on different device + mp.spawn(run_all_reduce, args=(M, N, cutlass.Float32), nprocs=WORLD_SIZE, join=True) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py b/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py new file mode 100644 index 00000000..79a23079 --- /dev/null +++ b/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import os +from typing import Tuple +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import make_ptr + + +""" +An Example demonstrating how to call off-the-shelf kernel by-passing dlpack protocol + +The example shows how to directly pass pointers from PyTorch tensors to off-the-shelf kernels +written by CuTe DSL with a thin customized wrapper jit function. The jit function will be +compiled with inline without introducing overhead. + +To run this example: + +.. code-block:: bash + + python examples/ampere/call_bypass_dlpack.py + + +It's worth to mention that by-passing dlpack protocol can resolve the issue that dlpack doesn't handle shape-1 +mode correctly. For example, the following code will fail, because dlpack will convert the shape-1 mode +with stride-1 which propagate alignment incorrectly. + +.. code-block:: python + + @cute.kernel + def fails_kernel(gX: cute.Tensor): + bidx, _, _ = cute.arch.block_idx() + mX = gX[None, bidx, None] # We wish to retain alignment + # assert mX.iterator.alignment == 16 + + + @cute.jit + def fails(gX_: cute.Tensor): + gX = gX_ + fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1)) + + + gX_torch = torch.rand((128, 1, 128), device="cuda", dtype=torch.bfloat16) + fails(from_dlpack(gX_torch, assumed_align=16)) + +""" + +# Add the current directory to sys.path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from tensorop_gemm import TensorOpGemm + + +@cute.jit +def tensor_op_gemm_wrapper( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + c_ptr: cute.Pointer, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + l: cutlass.Int32, +): + print(f"\n[DSL INFO] Input Parameters:") + print(f"[DSL INFO] mnkl: {(m, n, k, l)}") + + # Assume alignment of shape to call tensorop_gemm example + m = cute.assume(m, divby=8) + n = cute.assume(n, divby=8) + + # Torch is row major + a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2)) + b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2)) + c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2)) + mA = cute.make_tensor(a_ptr, layout=a_layout) + mB = cute.make_tensor(b_ptr, layout=b_layout) + mC = cute.make_tensor(c_ptr, layout=c_layout) + + print(f"[DSL INFO] mA: {mA}") + print(f"[DSL INFO] mB: {mB}") + print(f"[DSL INFO] mC: {mC}") + + tensor_op_gemm = TensorOpGemm( + a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1) + ) + print(f"\n[DSL INFO] Created TensorOpGemm instance") + print(f"[DSL INFO] Input dtype: {a_ptr.value_type}") + print(f"[DSL INFO] Output dtype: {c_ptr.value_type}") + print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}") + print(f"[DSL INFO] Atom layout: {(2, 2, 1)}") + + # No need to compile inside jit function + tensor_op_gemm(mA, mB, mC) + print(f"\n[DSL INFO] Executed TensorOpGemm") + + +def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): + print(f"\nRunning TensorOpGemm test with:") + print(f"Tensor dimensions: {mnkl}") + + # (M,K,L) + a = torch.randn( + mnkl[3], mnkl[2], mnkl[0], dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + # (N,K,L) + b = torch.randn( + mnkl[3], mnkl[2], mnkl[1], dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + # (N,M,L) + c = torch.randn( + mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda" + ).permute(1, 2, 0) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + a_ptr = make_ptr( + cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + b_ptr = make_ptr( + cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + c_ptr = make_ptr( + cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, *mnkl) + torch.cuda.synchronize() + + ref = torch.einsum("mkl,nkl->mnl", a, b) + torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) + print(f"\n[DSL INFO] Results verified successfully!") + print(f"First few elements of result: \n{c[:3, :3, :3]}") + + +if __name__ == "__main__": + run_tensor_op_gemm_wrapper((512, 256, 128, 16)) diff --git a/examples/python/CuTeDSL/ampere/call_from_jit.py b/examples/python/CuTeDSL/ampere/call_from_jit.py index ffe2eb70..ee71e53f 100644 --- a/examples/python/CuTeDSL/ampere/call_from_jit.py +++ b/examples/python/CuTeDSL/ampere/call_from_jit.py @@ -226,15 +226,15 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): print(f"c: {c.shape}, dtype: {c.dtype}\n") buffer_a = BufferWithLayout( - make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem), + make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), (2, 1, 0), ) buffer_b = BufferWithLayout( - make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem), + make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), (2, 1, 0), ) buffer_c = BufferWithLayout( - make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem), + make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), (2, 1, 0), ) diff --git a/examples/python/CuTeDSL/ampere/distributed_vector_add.py b/examples/python/CuTeDSL/ampere/distributed_vector_add.py new file mode 100644 index 00000000..252d5def --- /dev/null +++ b/examples/python/CuTeDSL/ampere/distributed_vector_add.py @@ -0,0 +1,189 @@ +import os +import torch +import argparse +from typing import Type +from cuda.bindings import driver + +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +import torch.multiprocessing as mp + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + + +def setup(rank, world_size): + # set environment variables for torch.distributed environment + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12995" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def cleanup(): + dist.destroy_process_group() + + +@cute.kernel +def vector_add_kernel( + g0: cute.Tensor, + g1: cute.Tensor, + g2: cute.Tensor, + g3: cute.Tensor, + g4: cute.Tensor, + g5: cute.Tensor, + g6: cute.Tensor, + g7: cute.Tensor, + gOut: cute.Tensor, + tv_layout: cute.Layout, +): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cta_coord = (None, bidx) + local_tile_out = gOut[cta_coord] + local_tile_list = [ + g0[cta_coord], + g1[cta_coord], + g2[cta_coord], + g3[cta_coord], + g4[cta_coord], + g5[cta_coord], + g6[cta_coord], + g7[cta_coord], + ] + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + g0.element_type, + memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, + memory_scope=cute.nvgpu.common.MemoryScope.SYS, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + g0.element_type, + memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, + memory_scope=cute.nvgpu.common.MemoryScope.SYS, + ) + tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1]) + thr_copy = tiled_copy.get_slice(tidx) + + thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list] + thr_out = thr_copy.partition_D(local_tile_out) + frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list] + frg_acc = cute.make_fragment_like(thr_out) + frg_acc.fill(0.0) + + for thr, frg in zip(thr_tensor_list, frg_tensor_list): + cute.copy(copy_atom_load, thr, frg) + tmp = frg.load() + frg_acc.load() + frg_acc.store(tmp) + + cute.copy(copy_atom_store, frg_acc, thr_out) + + +@cute.jit +def vector_add( + m0: cute.Tensor, + m1: cute.Tensor, + m2: cute.Tensor, + m3: cute.Tensor, + m4: cute.Tensor, + m5: cute.Tensor, + m6: cute.Tensor, + m7: cute.Tensor, + output: cute.Tensor, +): + # define constants for future use + num_of_elements = cute.size(m0.layout) + # 128 threads per block and 4 elements per thread + tv_layout = cute.make_layout(((128), (4)), stride=((1), (1))) + tile = cute.size(tv_layout.shape) + + tensors = [m0, m1, m2, m3, m4, m5, m6, m7] + divided_tensors = [ + cute.zipped_divide(tensor, cute.make_layout(tile)) for tensor in tensors + ] + gOut = cute.zipped_divide(output, cute.make_layout(tile)) # ((Tile),(Rest)) + vector_add_kernel( + divided_tensors[0], + divided_tensors[1], + divided_tensors[2], + divided_tensors[3], + divided_tensors[4], + divided_tensors[5], + divided_tensors[6], + divided_tensors[7], + gOut, + tv_layout, + ).launch( + grid=[num_of_elements // tile, 1, 1], + block=[tv_layout.shape[0], 1, 1], + ) + + +def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]): + setup(rank, world_size) + + t = symm_mem.empty(M * N, device="cuda") + hdl = symm_mem.rendezvous(t, dist.group.WORLD) + # get tensors from other devices from the symmetric memory + tensor_list = [hdl.get_buffer(rank, t.shape, t.dtype) for rank in range(world_size)] + tensor_list[rank].random_(0, 100) + + # enable peer access + driver.cuInit(0) + dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)] + ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list] + driver.cuCtxSetCurrent(ctx_list[rank]) + for i in range(world_size): + if i == rank: + continue + driver.cuCtxEnablePeerAccess(ctx_list[i], 0) + + output = torch.zeros(M * N, device=f"cuda:{rank}") + + # we have to explicitly pass each tensor instead of a list of tensors + vector_add( + from_dlpack(tensor_list[0], assumed_align=32), + from_dlpack(tensor_list[1], assumed_align=32), + from_dlpack(tensor_list[2], assumed_align=32), + from_dlpack(tensor_list[3], assumed_align=32), + from_dlpack(tensor_list[4], assumed_align=32), + from_dlpack(tensor_list[5], assumed_align=32), + from_dlpack(tensor_list[6], assumed_align=32), + from_dlpack(tensor_list[7], assumed_align=32), + from_dlpack(output, assumed_align=32), + ) + + sum_tensor = sum([tensor.cpu() for tensor in tensor_list]) + + if sum(sum_tensor.cpu() == output.cpu()) == sum_tensor.numel(): + print("test passed") + else: + print("test failed") + print(sum_tensor.cpu()) + print(output.cpu()) + + cleanup() + + +def main(): + + world_size = torch.cuda.device_count() + M = 1024 + N = 1024 + + # each process will run run_vector_add on different device + mp.spawn( + run_vector_add, + args=(world_size, M, N, cutlass.Float32), + nprocs=world_size, + join=True, + ) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/python/CuTeDSL/ampere/dynamic_smem_size.py b/examples/python/CuTeDSL/ampere/dynamic_smem_size.py new file mode 100644 index 00000000..dba50548 --- /dev/null +++ b/examples/python/CuTeDSL/ampere/dynamic_smem_size.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import cutlass.cute as cute +import cutlass + +""" +Example of automatic shared memory size computation for configuring kernel launch + +This example demonstrates how to let the DSL automatically set shared memory + size for a kernel launch rather explicitly configuring it at launch time, + provided that developers are using `SmemAllocator` for all allocations. + +Usage: + python dynamic_smem_size.py # Show auto inference +""" + + +@cute.struct +class SharedData: + """A struct to demonstrate shared memory allocation.""" + + values: cute.struct.MemRange[cutlass.Float32, 64] # 256 bytes + counter: cutlass.Int32 # 4 bytes + flag: cutlass.Int8 # 1 byte + + +@cute.kernel +def kernel(): + """ + Example kernel that allocates shared memory. + The total allocation will be automatically calculated when smem=None. + """ + allocator = cutlass.utils.SmemAllocator() + + # Allocate various types of shared memory + shared_data = allocator.allocate(SharedData) + raw_buffer = allocator.allocate(512, byte_alignment=64) + int_array = allocator.allocate_array(element_type=cutlass.Int32, num_elems=128) + tensor_smem = allocator.allocate_tensor( + element_type=cutlass.Float16, + layout=cute.make_layout((32, 16)), + byte_alignment=16, + swizzle=None, + ) + return + + +@cute.kernel +def kernel_no_smem(): + """ + Example kernel that does not allocates shared memory. + The total allocation will be automatically calculated as 0 when smem=None. + """ + tidx, _, _ = cute.arch.block_idx() + if tidx == 0: + cute.printf("Hello world") + return + + +if __name__ == "__main__": + # Initialize CUDA context + cutlass.cuda.initialize_cuda_context() + + print("Launching kernel with auto smem size. (launch config `smem=None`)") + + # Compile the example + @cute.jit + def launch_kernel1(): + k = kernel() + k.launch( + grid=(1, 1, 1), + block=(1, 1, 1), + ) + print(f"Kernel recorded internal smem usage: {k.smem_usage()}") + + @cute.jit + def launch_kernel2(): + k = kernel_no_smem() + k.launch( + grid=(1, 1, 1), + block=(1, 1, 1), + ) + print(f"Kernel recorded internal smem usage: {k.smem_usage()}") + + cute.compile(launch_kernel1) + cute.compile(launch_kernel2) + + print("PASS") diff --git a/examples/python/CuTeDSL/ampere/flash_attention_v2.py b/examples/python/CuTeDSL/ampere/flash_attention_v2.py index 441238cc..c6f8ff4c 100644 --- a/examples/python/CuTeDSL/ampere/flash_attention_v2.py +++ b/examples/python/CuTeDSL/ampere/flash_attention_v2.py @@ -327,7 +327,6 @@ class FlashAttentionForwardAmpere: ).launch( grid=grid_dim, block=[self._num_threads, 1, 1], - smem=SharedStorage.size_in_bytes(), stream=stream, ) @@ -1014,13 +1013,10 @@ class FlashAttentionForwardAmpere: ) # compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e)) - acc_S_row_exp = cute.TensorSSA( - self._exp2f( - acc_S_row * softmax_params.softmax_scale_log2 - - row_max_cur_row * softmax_params.softmax_scale_log2 - ), - tuple(acc_S_row.shape), - cutlass.Float32, + acc_S_row_exp = cute.math.exp2( + acc_S_row * softmax_params.softmax_scale_log2 + - row_max_cur_row * softmax_params.softmax_scale_log2, + fastmath=True, ) # acc_S_row_sum => f32 acc_S_row_sum = acc_S_row_exp.reduce( @@ -1028,9 +1024,10 @@ class FlashAttentionForwardAmpere: ) # if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum. if cutlass.const_expr(not is_first_n_block): - prev_minus_cur_exp = self._exp2f( + prev_minus_cur_exp = cute.math.exp2( row_max_prev_row * softmax_params.softmax_scale_log2 - - row_max_cur_row * softmax_params.softmax_scale_log2 + - row_max_cur_row * softmax_params.softmax_scale_log2, + fastmath=True, ) acc_S_row_sum = ( acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp @@ -1141,26 +1138,6 @@ class FlashAttentionForwardAmpere: """ return self._threadquad_reduce(val, lambda x, y: x + y) - def _exp2f( - self, x: Union[cute.TensorSSA, cutlass.Float32] - ) -> Union[cute.TensorSSA, cutlass.Float32]: - """exp2f calculation for both vector and scalar. - - :param x: input value - :type x: cute.TensorSSA or cutlass.Float32 - :return: exp2 value - :rtype: cute.TensorSSA or cutlass.Float32 - """ - if isinstance(x, cute.TensorSSA): - res = cute.make_fragment(x.shape, cutlass.Float32) - res.store(x) - - for i in range(cute.size(x.shape)): - res[i] = self._exp2f(res[i]) - - return res.load() - return cute.arch.exp2(x) - def run( dtype: Type[cutlass.Numeric], diff --git a/examples/python/CuTeDSL/ampere/sgemm.py b/examples/python/CuTeDSL/ampere/sgemm.py index 8058f24d..e7722a8d 100644 --- a/examples/python/CuTeDSL/ampere/sgemm.py +++ b/examples/python/CuTeDSL/ampere/sgemm.py @@ -136,10 +136,6 @@ class SGemm: stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)), ) - smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes( - mB.element_type, sB_layout - ) - # /////////////////////////////////////////////////////////////////////////////// # Create copy layouts that will be used for asynchronous # global memory -> shared memory copies: @@ -258,7 +254,6 @@ class SGemm: ).launch( grid=grid_dim, block=[cute.size(atoms_layout), 1, 1], - smem=smem_size, stream=stream, ) @@ -738,14 +733,20 @@ def run( print("Compiling kernel with cute.compile ...") start_time = time.time() - gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream) + compiled_fn = cute.compile( + sgemm, + a_tensor, + b_tensor, + c_tensor, + stream=current_stream, + ) compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") print("Executing GEMM kernel...") if not skip_ref_check: - gemm(a_tensor, b_tensor, c_tensor) + compiled_fn(a_tensor, b_tensor, c_tensor) torch.cuda.synchronize() print("Verifying results...") ref = torch.einsum("mk,nk->mn", a, b) @@ -804,7 +805,7 @@ def run( ) avg_time_us = testing.benchmark( - gemm, + compiled_fn, workspace_generator=generate_tensors, workspace_count=workspace_count, stream=current_stream, @@ -837,6 +838,7 @@ if __name__ == "__main__": parser.add_argument("--c_major", choices=["n", "m"], default="n") parser.add_argument("--warmup_iterations", default=2, type=int) parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--static_shape", action="store_true") parser.add_argument("--skip_ref_check", action="store_true") parser.add_argument( "--use_cold_l2", diff --git a/examples/python/CuTeDSL/ampere/smem_allocator.py b/examples/python/CuTeDSL/ampere/smem_allocator.py index 8c54a5a6..f9f5c1e0 100644 --- a/examples/python/CuTeDSL/ampere/smem_allocator.py +++ b/examples/python/CuTeDSL/ampere/smem_allocator.py @@ -69,7 +69,7 @@ class complex: class SharedStorage: # struct elements with natural alignment a: cute.struct.MemRange[cutlass.Float32, 32] # array - b: cutlass.Int64 # scalar + b: cutlass.Int64 # saclar c: complex # nested struct # struct elements with strict alignment x: cute.struct.Align[ diff --git a/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/examples/python/CuTeDSL/ampere/tensorop_gemm.py index 86110f36..413d1fbf 100644 --- a/examples/python/CuTeDSL/ampere/tensorop_gemm.py +++ b/examples/python/CuTeDSL/ampere/tensorop_gemm.py @@ -471,7 +471,7 @@ class TensorOpGemm: cute.arch.sync_threads() # Start async loads for the first k-tile. Here we take care of the k residue # via if/else check along the k dimension. Because we shifted the identity tensor - # by the residue_k and because the identity tensor is a counting tensor, the + # by the residue_k and because the identity tensor is a coord tensor, the # values of any identity tensor element that is poison is less than -1 num_smem_stages = cute.size(tAsA, mode=[3]) k_tile_count = cute.size(tAgA, mode=[3]) @@ -683,7 +683,7 @@ class TensorOpGemm: # Copy results of D back to shared memory cute.autovec_copy(tCrD, tCsC) - # Create counting tensor for C + # Create coord tensor for C ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) mcC = cute.make_identity_tensor( ( diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py index 81720ee9..c9c24b84 100644 --- a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py @@ -610,7 +610,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, ) return @@ -797,7 +796,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) - k_block_cnt = cute.size(gA_mkl, mode=[3]) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) # # Partition global tensor for TiledMMA_A/B/C @@ -946,17 +945,17 @@ class Sm100BlockScaledPersistentDenseGemmKernel: (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) # # Tma load loop # - for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire( ab_producer_state, peek_ab_empty_status @@ -992,10 +991,10 @@ class Sm100BlockScaledPersistentDenseGemmKernel: mcast_mask=sfb_full_mcast_mask, ) - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) @@ -1103,10 +1102,10 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] - # Peek (try_wait) AB buffer full for k_block = 0 + # Peek (try_wait) AB buffer full for k_tile = 0 ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_block_cnt and is_leader_cta: + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state ) @@ -1125,7 +1124,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # # Mma mainloop # - for k_block in range(k_block_cnt): + for k_tile in range(k_tile_cnt): if is_leader_cta: # Conditionally wait for AB buffer full ab_pipeline.consumer_wait( @@ -1154,44 +1153,44 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ) # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB - num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord = ( + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( None, None, - kphase_idx, + kblock_idx, ab_consumer_state.index, ) # Set SFA/SFB tensor to tiled_mma - sf_kphase_coord = (None, None, kphase_idx) + sf_kblock_coord = (None, None, kblock_idx) tiled_mma.set( tcgen05.Field.SFA, - tCtSFA[sf_kphase_coord].iterator, + tCtSFA[sf_kblock_coord].iterator, ) tiled_mma.set( tcgen05.Field.SFB, - tCtSFB[sf_kphase_coord].iterator, + tCtSFB[sf_kblock_coord].iterator, ) cute.gemm( tiled_mma, tCtAcc, - tCrA[kphase_coord], - tCrB[kphase_coord], + tCrA[kblock_coord], + tCrB[kblock_coord], tCtAcc, ) - # Enable accumulate on tCtAcc after first kphase + # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty ab_pipeline.consumer_release(ab_consumer_state) - # Peek (try_wait) AB buffer full for k_block = k_block + 1 + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 ab_consumer_state.advance() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_block_cnt: + if ab_consumer_state.count < k_tile_cnt: if is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm.py b/examples/python/CuTeDSL/blackwell/dense_gemm.py index c36c28a8..f5a83729 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -486,7 +486,6 @@ class DenseGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, ) return @@ -660,7 +659,7 @@ class DenseGemmKernel: gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) - k_block_cnt = cute.size(gA_mkl, mode=[3]) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) # # Partition global tensor for TiledMMA_A/B/C @@ -788,19 +787,19 @@ class DenseGemmKernel: # # Pipelining TMA load A/B and MMA mainloop # - prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt) + prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) if warp_idx == 0: - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) # # Prefetch TMA load A/B # - for prefetch_idx in cutlass.range(prefetch_k_block_cnt, unroll=1): + for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) @@ -820,27 +819,27 @@ class DenseGemmKernel: mcast_mask=b_full_mcast_mask, ) - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) - # Peek (try_wait) AB buffer full for k_block = 0 + # Peek (try_wait) AB buffer full for k_tile = 0 peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_block_cnt and is_leader_cta: + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # # MMA mainloop # - for k_block in range(k_block_cnt): + for k_tile in range(k_tile_cnt): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: # TMA load A/B cute.copy( tma_atom_a, @@ -862,35 +861,35 @@ class DenseGemmKernel: ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) # tCtAcc += tCrA * tCrB - num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord = (None, None, kphase_idx, ab_consumer_state.index) + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) cute.gemm( tiled_mma, tCtAcc, - tCrA[kphase_coord], - tCrB[kphase_coord], + tCrA[kblock_coord], + tCrB[kblock_coord], tCtAcc, ) - # Enable accumulate on tCtAcc after first kphase + # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty ab_pipeline.consumer_release(ab_consumer_state) - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) - # Peek (try_wait) AB buffer full for k_block = k_block + 1 + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 ab_consumer_state.advance() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_block_cnt: + if ab_consumer_state.count < k_tile_cnt: if is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state @@ -1009,8 +1008,8 @@ class DenseGemmKernel: # Wait A/B buffer empty # if warp_idx == 0: - # Reverse prefetch_k_block_cnt times to next available buffer - for i in range(prefetch_k_block_cnt): + # Reverse prefetch_k_tile_cnt times to next available buffer + for i in range(prefetch_k_tile_cnt): ab_producer_state.reverse() ab_pipeline.producer_tail(ab_producer_state) return diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py index 3251abbd..f5022a82 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py @@ -510,7 +510,6 @@ class PersistentDenseGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, ) return @@ -669,7 +668,7 @@ class PersistentDenseGemmKernel: gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) - k_block_cnt = cute.size(gA_mkl, mode=[3]) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) # # Partition global tensor for TiledMMA_A/B/C @@ -774,17 +773,17 @@ class PersistentDenseGemmKernel: (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) # # Tma load loop # - for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire( ab_producer_state, peek_ab_empty_status @@ -806,10 +805,10 @@ class PersistentDenseGemmKernel: mcast_mask=b_full_mcast_mask, ) - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_block_cnt: + if ab_producer_state.count < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) @@ -877,10 +876,10 @@ class PersistentDenseGemmKernel: # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] - # Peek (try_wait) AB buffer full for k_block = 0 + # Peek (try_wait) AB buffer full for k_tile = 0 ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_block_cnt and is_leader_cta: + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state ) @@ -899,7 +898,7 @@ class PersistentDenseGemmKernel: # # Mma mainloop # - for k_block in range(k_block_cnt): + for k_tile in range(k_tile_cnt): if is_leader_cta: # Conditionally wait for AB buffer full ab_pipeline.consumer_wait( @@ -907,32 +906,32 @@ class PersistentDenseGemmKernel: ) # tCtAcc += tCrA * tCrB - num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord = ( + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( None, None, - kphase_idx, + kblock_idx, ab_consumer_state.index, ) cute.gemm( tiled_mma, tCtAcc, - tCrA[kphase_coord], - tCrB[kphase_coord], + tCrA[kblock_coord], + tCrB[kblock_coord], tCtAcc, ) - # Enable accumulate on tCtAcc after first kphase + # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty ab_pipeline.consumer_release(ab_consumer_state) - # Peek (try_wait) AB buffer full for k_block = k_block + 1 + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 ab_consumer_state.advance() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_block_cnt: + if ab_consumer_state.count < k_tile_cnt: if is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py index d69ad401..8f5e172e 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py @@ -110,17 +110,6 @@ Constraints: """ -class PipelineStateMinimal: - """ - Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. - """ - - def __init__(self, count, index, phase): - self.count = count - self.index = index - self.phase = phase - - class DenseGemmKernel: """ This class implements batched matrix multiplication (C = A x B) with support for various data types @@ -497,7 +486,6 @@ class DenseGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, ) return @@ -576,13 +564,19 @@ class DenseGemmKernel: pipeline.Agent.Thread, num_tma_producer ) ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, - barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), cta_layout_vmnk=cluster_layout_vmnk, ) + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -590,10 +584,10 @@ class DenseGemmKernel: pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta ) acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, producer_group=acc_pipeline_producer_group, consumer_group=acc_pipeline_consumer_group, - barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), cta_layout_vmnk=cluster_layout_vmnk, ) acc_producer_state = pipeline.make_pipeline_state( @@ -665,7 +659,7 @@ class DenseGemmKernel: gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) - k_block_cnt = cute.size(gA_mkl, mode=[3]) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) # # Partition global tensor for TiledMMA_A/B/C @@ -793,24 +787,12 @@ class DenseGemmKernel: # /////////////////////////////////////////////////////////////////////////////// # MAINLOOP # /////////////////////////////////////////////////////////////////////////////// - prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt) + prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) if warp_idx == 0: - for k_block in cutlass.range( - k_block_cnt, - pipelining=self.num_ab_stage - 2, + for k_tile in cutlass.range( + k_tile_cnt, + prefetch_stages=self.num_ab_stage - 2, ): - ab_producer_state = PipelineStateMinimal( - k_block, - k_block % self.num_ab_stage, - cutlass.Int32((k_block // self.num_ab_stage) % 2) ^ 1, - ) - - ab_consumer_state = PipelineStateMinimal( - k_block, - k_block % self.num_ab_stage, - cutlass.Int32((k_block // self.num_ab_stage) % 2), - ) - # wait for AB buffer empty ab_pipeline.producer_acquire(ab_producer_state) @@ -835,22 +817,26 @@ class DenseGemmKernel: ab_pipeline.consumer_wait(ab_consumer_state) # tCtAcc += tCrA * tCrB - num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord = (None, None, kphase_idx, ab_consumer_state.index) + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) cute.gemm( tiled_mma, tCtAcc, - tCrA[kphase_coord], - tCrB[kphase_coord], + tCrA[kblock_coord], + tCrB[kblock_coord], tCtAcc, ) - # Enable accumulate on tCtAcc after first kphase + # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty ab_pipeline.consumer_release(ab_consumer_state) + + ab_producer_state.advance() + ab_consumer_state.advance() + # Async arrive accumulator buffer full if is_leader_cta: acc_pipeline.producer_commit(acc_producer_state) @@ -964,12 +950,10 @@ class DenseGemmKernel: # Wait A/B buffer empty # if warp_idx == 0: - ab_producer_state = PipelineStateMinimal( - k_block_cnt, - k_block_cnt % self.num_ab_stage, - cutlass.Int32((k_block_cnt // self.num_ab_stage) % 2) ^ 1, - ) - ab_pipeline.producer_acquire(ab_producer_state) + # Reverse prefetch_k_tile_cnt times to next available buffer + for i in range(prefetch_k_tile_cnt): + ab_producer_state.reverse() + ab_pipeline.producer_tail(ab_producer_state) return def epilog_tmem_copy_and_partition( @@ -1579,7 +1563,6 @@ def run_dense_gemm( warmup_iterations: int = 0, iterations: int = 1, skip_ref_check: bool = False, - measure_launch_overhead=False, ): """ Prepare A/B/C tensors, launch GPU kernel, and reference checking. @@ -1725,7 +1708,7 @@ def run_dense_gemm( ref_c = ref elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: # m major: (l, n, m) -> (m, n, l) - # k major: (l, m, n) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) shape = (l, m, n) if c_major == "n" else (l, n, m) f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( diff --git a/examples/python/CuTeDSL/blackwell/fmha.py b/examples/python/CuTeDSL/blackwell/fmha.py index 259ddb85..560129df 100644 --- a/examples/python/CuTeDSL/blackwell/fmha.py +++ b/examples/python/CuTeDSL/blackwell/fmha.py @@ -253,6 +253,11 @@ class MaskType(enum.Enum): RESIDUAL_MASK = enum.auto() CAUSAL_MASK = enum.auto() + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size, size) + + class BlackwellFusedMultiHeadAttentionForward: def __init__( self, @@ -662,7 +667,6 @@ class BlackwellFusedMultiHeadAttentionForward: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, - smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -760,32 +764,85 @@ class BlackwellFusedMultiHeadAttentionForward: smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - load_q_pipeline = self.make_and_init_load_q_pipeline( - storage.load_q_mbar_ptr.data_ptr() - ) - load_kv_pipeline = self.make_and_init_load_kv_pipeline( - storage.load_kv_mbar_ptr.data_ptr() - ) - mma_s0_pipeline = self.make_and_init_mma_si_pipeline( - storage.mma_s0_mbar_ptr.data_ptr() - ) - mma_s1_pipeline = self.make_and_init_mma_si_pipeline( - storage.mma_s1_mbar_ptr.data_ptr() - ) - s0_corr_pipeline = self.make_and_init_si_corr_pipeline( - storage.s0_corr_mbar_ptr.data_ptr() - ) - s1_corr_pipeline = self.make_and_init_si_corr_pipeline( - storage.s1_corr_mbar_ptr.data_ptr() - ) - corr_epi_pipeline = self.make_and_init_corr_epi_pipeline( - storage.corr_epi_mbar_ptr.data_ptr() - ) - mma_corr_pipeline = self.make_and_init_mma_corr_pipeline( - storage.mma_corr_mbar_ptr.data_ptr() - ) - s0_s1_sequence_pipeline = self.make_and_init_si_sequence_pipeline( - storage.s0_s1_sequence_mbar_ptr.data_ptr() + load_q_producer, load_q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_q_bytes, + barrier_storage=storage.load_q_mbar_ptr.data_ptr(), + ).make_participants() + load_kv_producer, load_kv_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.kv_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_kv_bytes, + barrier_storage=storage.load_kv_mbar_ptr.data_ptr(), + ).make_participants() + mma_s0_producer, mma_s0_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_softmax_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + barrier_storage=storage.mma_s0_mbar_ptr.data_ptr(), + ).make_participants() + mma_s1_producer, mma_s1_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_softmax_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + barrier_storage=storage.mma_s1_mbar_ptr.data_ptr(), + ).make_participants() + s0_corr_producer, s0_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.softmax_corr_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s0_corr_mbar_ptr.data_ptr(), + ).make_participants() + s1_corr_producer, s1_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.softmax_corr_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s1_corr_mbar_ptr.data_ptr(), + ).make_participants() + corr_epi_producer, corr_epi_consumer = pipeline.PipelineAsync.create( + num_stages=self.epi_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len([self.epilogue_warp_id]) + ), + barrier_storage=storage.corr_epi_mbar_ptr.data_ptr(), + ).make_participants() + mma_corr_producer, mma_corr_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_corr_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.mma_corr_mbar_ptr.data_ptr(), + ).make_participants() + s0_s1_sequence_producer, s0_s1_sequence_consumer = ( + pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + barrier_storage=storage.s0_s1_sequence_mbar_ptr.data_ptr(), + ).make_participants() ) tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() @@ -802,7 +859,6 @@ class BlackwellFusedMultiHeadAttentionForward: ) ), ) - cute.arch.mbarrier_init_fence() # Generate smem tensor Q/K/V/O @@ -818,22 +874,18 @@ class BlackwellFusedMultiHeadAttentionForward: # Strip swizzle info to reuse smem sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner) sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer) - sO = storage.sO.get_tensor( o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner ) qk_thr_mma = qk_tiled_mma.get_slice(0) # default 1sm pv_thr_mma = pv_tiled_mma.get_slice(0) # default 1sm - tSrQ = qk_thr_mma.make_fragment_A(sQ) tSrK = qk_thr_mma.make_fragment_B(sK) tOrV = pv_thr_mma.make_fragment_B(sV) - qk_acc_shape = qk_thr_mma.partition_shape_C( (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) ) tStS = qk_thr_mma.make_fragment_C(qk_acc_shape) - pv_acc_shape = pv_thr_mma.partition_shape_C( (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) ) @@ -841,13 +893,11 @@ class BlackwellFusedMultiHeadAttentionForward: tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) tStS1 = cute.make_tensor(tStS.iterator + self.tmem_s1_offset, tStS.layout) - tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem_o1_offset, tOtO.layout) tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] - tOrP0 = cute.make_tensor( tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, @@ -858,12 +908,10 @@ class BlackwellFusedMultiHeadAttentionForward: + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, tOrP.layout, ) - cute.arch.barrier( barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta, ) - # /////////////////////////////////////////////////////////////////////////////// # EMPTY # /////////////////////////////////////////////////////////////////////////////// @@ -876,13 +924,6 @@ class BlackwellFusedMultiHeadAttentionForward: if warp_idx == self.load_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - q_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.q_stage - ) - kv_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.kv_stage - ) - tile_sched = create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) @@ -894,7 +935,6 @@ class BlackwellFusedMultiHeadAttentionForward: continue_cond = False cuseqlen_q = Int32(0) seqlen_q = mQ_qdl.shape[0] - if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q @@ -905,7 +945,6 @@ class BlackwellFusedMultiHeadAttentionForward: seqlen_q, ) ) - if not continue_cond: mQ_qdl_ = mQ_qdl mK_kdl_ = mK_kdl @@ -991,85 +1030,62 @@ class BlackwellFusedMultiHeadAttentionForward: # Q0 q0_coord = 2 * curr_block_coord_q[0] - load_q_pipeline.producer_acquire(q_producer_state) + q0_handle = load_q_producer.acquire_and_advance() cute.copy( tma_atom_q, tQgQ[None, q0_coord], - tQsQ[None, q_producer_state.index], - tma_bar_ptr=load_q_pipeline.producer_get_barrier( - q_producer_state - ), + tQsQ[None, q0_handle.index], + tma_bar_ptr=q0_handle.barrier, ) - q_producer_state.advance() - # K0 kv_coord = 0 # seqlen_kv_loop - load_kv_pipeline.producer_acquire(kv_producer_state) + k_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_k, tKgK[None, kv_coord], - tKsK[None, kv_producer_state.index], - tma_bar_ptr=load_kv_pipeline.producer_get_barrier( - kv_producer_state - ), + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, ) - kv_producer_state.advance() - # Q1 q1_coord = q0_coord + 1 - load_q_pipeline.producer_acquire(q_producer_state) + q1_handle = load_q_producer.acquire_and_advance() cute.copy( tma_atom_q, tQgQ[None, q1_coord], - tQsQ[None, q_producer_state.index], - tma_bar_ptr=load_q_pipeline.producer_get_barrier( - q_producer_state - ), + tQsQ[None, q1_handle.index], + tma_bar_ptr=q1_handle.barrier, ) - q_producer_state.advance() - # V0 - load_kv_pipeline.producer_acquire(kv_producer_state) + v_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_v, tVgV[None, kv_coord], - tVsV[None, kv_producer_state.index], - tma_bar_ptr=load_kv_pipeline.producer_get_barrier( - kv_producer_state - ), + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, ) - kv_producer_state.advance() kv_coord += 1 seqlen_kv_loop_steps = ( self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) - 1 ) - for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): # Ki - load_kv_pipeline.producer_acquire(kv_producer_state) + k_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_k, tKgK[None, kv_coord], - tKsK[None, kv_producer_state.index], - tma_bar_ptr=load_kv_pipeline.producer_get_barrier( - kv_producer_state - ), + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, ) - kv_producer_state.advance() # Vi - load_kv_pipeline.producer_acquire(kv_producer_state) - + v_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_v, tVgV[None, kv_coord], - tVsV[None, kv_producer_state.index], - tma_bar_ptr=load_kv_pipeline.producer_get_barrier( - kv_producer_state - ), + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, ) - kv_producer_state.advance() kv_coord += 1 # End of seqlen_kv loop @@ -1090,23 +1106,6 @@ class BlackwellFusedMultiHeadAttentionForward: barrier_id=self.tmem_alloc_sync_bar_id, number_of_threads=self.threads_per_warp, ) - mma_q_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.q_stage - ) - mma_kv_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.kv_stage - ) - mma_q_release_state = mma_q_consumer_state.clone() - mma_kv_release_state = mma_kv_consumer_state.clone() - mma_s0_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_softmax_stage - ) - mma_s1_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_softmax_stage - ) - mma_corr_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_corr_stage - ) tile_sched = create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) @@ -1135,15 +1134,13 @@ class BlackwellFusedMultiHeadAttentionForward: # GEMM_QK00 (Q0 * K0 -> S0) # 1. wait for Q0 - load_q_pipeline.consumer_wait(mma_q_consumer_state) - tSrQ0 = tSrQ[None, None, None, mma_q_consumer_state.index] - mma_q_consumer_state.advance() + q0_handle = load_q_consumer.wait_and_advance() + tSrQ0 = tSrQ[None, None, None, q0_handle.index] # 2. wait for K0 - load_kv_pipeline.consumer_wait(mma_kv_consumer_state) - tSrK0 = tSrK[None, None, None, mma_kv_consumer_state.index] - mma_kv_consumer_state.advance() + k_handle = load_kv_consumer.wait_and_advance() + tSrK0 = tSrK[None, None, None, k_handle.index] # 3. acquire empty S0 buffer - mma_s0_pipeline.producer_acquire(mma_s0_producer_state) + s0_handle = mma_s0_producer.acquire_and_advance() # 4. gemm num_kphases = cute.size(tSrQ0, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): @@ -1157,17 +1154,15 @@ class BlackwellFusedMultiHeadAttentionForward: tStS0, ) # 5. release S0 - mma_s0_pipeline.producer_commit(mma_s0_producer_state) - mma_s0_producer_state.advance() + s0_handle.commit() # End of GEMM (Q0 * K0 -> S0) # GEMM_QK10 (Q1 * K0 -> S1), K0 is ready in GEMM_QK00 # 1. wait for Q1 - load_q_pipeline.consumer_wait(mma_q_consumer_state) - tSrQ1 = tSrQ[None, None, None, mma_q_consumer_state.index] - mma_q_consumer_state.advance() + q1_handle = load_q_consumer.wait_and_advance() + tSrQ1 = tSrQ[None, None, None, q1_handle.index] # 2. acquire empty S1 - mma_s1_pipeline.producer_acquire(mma_s1_producer_state) + s1_handle = mma_s1_producer.acquire_and_advance() # 3. gemm num_kphases = cute.size(tSrQ1, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): @@ -1181,28 +1176,25 @@ class BlackwellFusedMultiHeadAttentionForward: tStS1, ) # 4. release S1 - mma_s1_pipeline.producer_commit(mma_s1_producer_state) - mma_s1_producer_state.advance() + s1_handle.commit() # 5. release K0 - load_kv_pipeline.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() + k_handle.release() # End of GEMM (Q1 * K0 -> S1) # Note: Q0 & Q1 are still needed in the seqlen_kv loop # so we need to release them after the seqlen_kv loop # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 - load_kv_pipeline.consumer_wait(mma_kv_consumer_state) - tOrVi = tOrV[None, None, None, mma_kv_consumer_state.index] - mma_kv_consumer_state.advance() + v_handle = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle.index] # 2. acquire corrected O0_partial # Note: acquire corr first to take it out of the critical # path since softmax takes longer - mma_corr_pipeline.producer_acquire(mma_corr_producer_state) + o0_handle = mma_corr_producer.acquire_and_advance() # 3. acquire P0 # this acquire returns the ownership of all of S0 to the mma warp # including the P0 part (inplaced in S0) - mma_s0_pipeline.producer_acquire(mma_s0_producer_state) + s0_handle = mma_s0_producer.acquire_and_advance() # 4. gemm num_kphases = cute.size(tOrP0, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): @@ -1216,22 +1208,21 @@ class BlackwellFusedMultiHeadAttentionForward: tOtO0, ) # 5. release accumulated O0_partial - mma_corr_pipeline.producer_commit(mma_corr_producer_state) - mma_corr_producer_state.advance() + o0_handle.commit() # End of GEMM_PV00 (P0 * V0 -> O0_partial) seqlen_kv_loop_steps = ( self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) - 1 ) + # O1 hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate pv_whether_acc = False for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): # GEMM_QK0i (Q0 * Ki -> S0) # 1. wait for Ki - load_kv_pipeline.consumer_wait(mma_kv_consumer_state) - tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] - mma_kv_consumer_state.advance() + k_handle = load_kv_consumer.wait_and_advance() + tSrKi = tSrK[None, None, None, k_handle.index] # 2. gemm inner_num_kphases = cute.size(tSrQ0, mode=[2]) for kphase_idx in cutlass.range( @@ -1247,15 +1238,14 @@ class BlackwellFusedMultiHeadAttentionForward: tStS0, ) # 3. release S0 - mma_s0_pipeline.producer_commit(mma_s0_producer_state) - mma_s0_producer_state.advance() + s0_handle.commit() # End of GEMM_QK0i (Q0 * Ki -> S0) # GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial), V(i-1) is ready in GEMM_PV0(i-1) # 1. acquire corrected O1_partial - mma_corr_pipeline.producer_acquire(mma_corr_producer_state) + o1_handle = mma_corr_producer.acquire_and_advance() # 2. acquire P1 - mma_s1_pipeline.producer_acquire(mma_s1_producer_state) + s1_handle = mma_s1_producer.acquire_and_advance() # 3. gemm inner_num_kphases = cute.size(tOrP0, mode=[2]) for kphase_idx in cutlass.range( @@ -1272,11 +1262,9 @@ class BlackwellFusedMultiHeadAttentionForward: ) pv_whether_acc = True # 4. release accumulated O1_partial - mma_corr_pipeline.producer_commit(mma_corr_producer_state) - mma_corr_producer_state.advance() + o1_handle.commit() # 5. release V(i-1) - load_kv_pipeline.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() + v_handle.release() # End of GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial) # GEMM_QK1i (Q1 * Ki -> S1), Q1 is ready in GEMM_QK10; Ki is ready in GEMM_QK0i @@ -1294,22 +1282,19 @@ class BlackwellFusedMultiHeadAttentionForward: tSrKi[kphase_coord_5], tStS1, ) - mma_s1_pipeline.producer_commit(mma_s1_producer_state) - mma_s1_producer_state.advance() + s1_handle.commit() # 2. release Ki - load_kv_pipeline.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() + k_handle.release() # End of GEMM_QK1i (Q1 * Ki -> S1) # GEMM_PV0i (P0 * Vi -> O0_partial) # 1. wait for Vi - load_kv_pipeline.consumer_wait(mma_kv_consumer_state) - tOrVi = tOrV[None, None, None, mma_kv_consumer_state.index] - mma_kv_consumer_state.advance() + v_handle = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle.index] # 2. acquire corrected O0_partial - mma_corr_pipeline.producer_acquire(mma_corr_producer_state) + o0_handle = mma_corr_producer.acquire_and_advance() # 3. acquire P0 - mma_s0_pipeline.producer_acquire(mma_s0_producer_state) + s0_handle = mma_s0_producer.acquire_and_advance() # 4. gemm inner_num_kphases = cute.size(tOrP0, mode=[2]) for kphase_idx in cutlass.range( @@ -1325,22 +1310,19 @@ class BlackwellFusedMultiHeadAttentionForward: tOtO0, ) # 5. release accumulated O0_partial - mma_corr_pipeline.producer_commit(mma_corr_producer_state) - mma_corr_producer_state.advance() + o0_handle.commit() # End of GEMM_PV0i (P0 * Vi -> O0_partial) # End of seqlen_kv loop # release Q0 & Q1 - load_q_pipeline.consumer_release(mma_q_release_state) - mma_q_release_state.advance() - load_q_pipeline.consumer_release(mma_q_release_state) - mma_q_release_state.advance() + q0_handle.release() + q1_handle.release() # GEMM_PV1(i_end) (P1 * Vi_end -> O1) # 1. acquire corrected O1_partial - mma_corr_pipeline.producer_acquire(mma_corr_producer_state) + o1_handle = mma_corr_producer.acquire_and_advance() # 2. acquire P1 - mma_s1_pipeline.producer_acquire(mma_s1_producer_state) + s1_handle = mma_s1_producer.acquire_and_advance() # 3. gemm num_kphases = cute.size(tOrP1, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): @@ -1354,18 +1336,14 @@ class BlackwellFusedMultiHeadAttentionForward: tOtO1, ) # 4. commit accumulated O1 - mma_corr_pipeline.producer_commit(mma_corr_producer_state) - mma_corr_producer_state.advance() + o1_handle.commit() # 5. release Vi_end - load_kv_pipeline.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() + v_handle.release() # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Commit S0 and S1 - mma_s0_pipeline.producer_commit(mma_s0_producer_state) - mma_s0_producer_state.advance() - mma_s1_pipeline.producer_commit(mma_s1_producer_state) - mma_s1_producer_state.advance() + s0_handle.commit() + s1_handle.commit() # Advance to next tile tile_sched.advance_to_next_work() @@ -1382,7 +1360,6 @@ class BlackwellFusedMultiHeadAttentionForward: alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf, ) - cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) # /////////////////////////////////////////////////////////////////////////////// @@ -1390,12 +1367,6 @@ class BlackwellFusedMultiHeadAttentionForward: # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.epilogue_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - - corr_epi_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.epi_stage - ) - corr_epi_release_state = corr_epi_consumer_state.clone() - tile_sched = create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) @@ -1418,7 +1389,6 @@ class BlackwellFusedMultiHeadAttentionForward: seqlen_q, ) ) - if not continue_cond: curr_block_coord_o = curr_block_coord mO_qdl_ = mO_qdl @@ -1453,27 +1423,23 @@ class BlackwellFusedMultiHeadAttentionForward: # wait from corr, issue tma store on smem # O0 # 1. wait for O0 final - corr_epi_pipeline.consumer_wait(corr_epi_consumer_state) - corr_epi_consumer_state.advance() + o0_handle = corr_epi_consumer.wait_and_advance() # 2. copy O0 to gmem cute.copy(tma_atom_o, tOsO[None, 0], tOgO[None, o0_coord]) cute.arch.cp_async_bulk_commit_group() # O1 # 1. wait for O1 final - corr_epi_pipeline.consumer_wait(corr_epi_consumer_state) - corr_epi_consumer_state.advance() + o1_handle = corr_epi_consumer.wait_and_advance() # 2. copy O1 to gmem cute.copy(tma_atom_o, tOsO[None, 1], tOgO[None, o1_coord]) cute.arch.cp_async_bulk_commit_group() # Ensure O0 buffer is ready to be released cute.arch.cp_async_bulk_wait_group(1, read=True) - corr_epi_pipeline.consumer_release(corr_epi_release_state) - corr_epi_release_state.advance() + o0_handle.release() # Ensure O1 buffer is ready to be released cute.arch.cp_async_bulk_wait_group(0, read=True) - corr_epi_pipeline.consumer_release(corr_epi_release_state) - corr_epi_release_state.advance() + o1_handle.release() # Advance to next tile tile_sched.advance_to_next_work() @@ -1496,9 +1462,10 @@ class BlackwellFusedMultiHeadAttentionForward: qk_thr_mma=qk_thr_mma, tStS=tStS, tStSi=tStS0, - mma_si_pipeline=mma_s0_pipeline, - si_corr_pipeline=s0_corr_pipeline, - s0_s1_sequence_pipeline=s0_s1_sequence_pipeline, + mma_si_consumer=mma_s0_consumer, + si_corr_producer=s0_corr_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, tile_sched_params=tile_sched_params, ) cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) @@ -1522,9 +1489,10 @@ class BlackwellFusedMultiHeadAttentionForward: qk_thr_mma=qk_thr_mma, tStS=tStS, tStSi=tStS1, - mma_si_pipeline=mma_s1_pipeline, - si_corr_pipeline=s1_corr_pipeline, - s0_s1_sequence_pipeline=s0_s1_sequence_pipeline, + mma_si_consumer=mma_s1_consumer, + si_corr_producer=s1_corr_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, tile_sched_params=tile_sched_params, ) cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) @@ -1535,19 +1503,6 @@ class BlackwellFusedMultiHeadAttentionForward: if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) - s0_corr_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.softmax_corr_stage - ) - s1_corr_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.softmax_corr_stage - ) - o_corr_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_corr_stage - ) - corr_epi_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.epi_stage - ) - cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) tScS = qk_thr_mma.partition_C(cS) @@ -1602,127 +1557,93 @@ class BlackwellFusedMultiHeadAttentionForward: if cutlass.const_expr(cum_seqlen_k is not None): cuseqlen_k = cum_seqlen_k[batch_coord] seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - # Ignore first signal from softmax as no correction is required - s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) - s0_corr_pipeline.consumer_release(s0_corr_consumer_state) - s0_corr_consumer_state.advance() - - s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) - + vec0_handle = s0_corr_consumer.wait_and_advance() + vec0_handle.release() + vec1_handle = s1_corr_consumer.wait_and_advance() seqlen_kv_loop_steps = ( self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) - 1 ) for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): - # wait for S0 - s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) + # wait for vec0 (row_wise current max & previous max) + vec0_handle = s0_corr_consumer.wait_and_advance() tTMEM_LOAD_VECrS = cute.make_fragment( tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype ) - # read row_wise new global max cute.copy( tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS ) - scale_ = scale_softmax_log2 * ( tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] ) - scale = cute.arch.exp2(scale_) - - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) + scale = cute.math.exp2(scale_, fastmath=True) + # wait for o0 + o0_handle = mma_corr_consumer.wait_and_advance() self.correction_rescale(pv_thr_mma, tOtO0, scale) - - s1_corr_pipeline.consumer_release(s1_corr_consumer_state) - s1_corr_consumer_state.advance() - + # release vec1 & o0 + vec1_handle.release() cute.arch.fence_view_async_tmem_store() + o0_handle.release() - mma_corr_pipeline.consumer_release(o_corr_consumer_state) - o_corr_consumer_state.advance() - - s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) - + # wait for vec1 (row_wise current max & previous max) + vec1_handle = s1_corr_consumer.wait_and_advance() cute.copy( tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS ) - scale_ = scale_softmax_log2 * ( tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] ) - scale = cute.arch.exp2(scale_) - - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) + scale = cute.math.exp2(scale_, fastmath=True) + o1_handle = mma_corr_consumer.wait_and_advance() self.correction_rescale(pv_thr_mma, tOtO1, scale) - - s0_corr_pipeline.consumer_release(s0_corr_consumer_state) - s0_corr_consumer_state.advance() - + vec0_handle.release() cute.arch.fence_view_async_tmem_store() - mma_corr_pipeline.consumer_release(o_corr_consumer_state) - o_corr_consumer_state.advance() + o1_handle.release() # End of seqlen_corr_loop_steps + vec1_handle.release() - s1_corr_pipeline.consumer_release(s1_corr_consumer_state) - s1_corr_consumer_state.advance() - - s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) - + # wait for vec0 (row_wise global sum) + vec0_handle = s0_corr_consumer.wait_and_advance() tTMEM_LOAD_VECrS = cute.make_fragment( tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype ) cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) cute.arch.fence_view_async_tmem_load() - - s0_corr_pipeline.consumer_release(s0_corr_consumer_state) - s0_corr_consumer_state.advance() - - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) - corr_epi_pipeline.producer_acquire(corr_epi_producer_state) - + vec0_handle.release() + # wait for o0 + o0_handle = mma_corr_consumer.wait_and_advance() + o0_final_handle = corr_epi_producer.acquire_and_advance() self.correction_epilog( pv_thr_mma, tOtO0, scale_output / tTMEM_LOAD_VECrS[0], sO[None, None, 0], ) + o0_handle.release() + o0_final_handle.commit() - mma_corr_pipeline.consumer_release(o_corr_consumer_state) - o_corr_consumer_state.advance() - - corr_epi_pipeline.producer_commit(corr_epi_producer_state) - corr_epi_producer_state.advance() - - s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) - # load from V1 + # wait for vec1 (row_wise global sum) + vec1_handle = s1_corr_consumer.wait_and_advance() cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) cute.arch.fence_view_async_tmem_load() - - s1_corr_pipeline.consumer_release(s1_corr_consumer_state) - s1_corr_consumer_state.advance() - - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) - - corr_epi_pipeline.producer_acquire(corr_epi_producer_state) + vec1_handle.release() + # wait for o1 + o1_handle = mma_corr_consumer.wait_and_advance() + o1_final_handle = corr_epi_producer.acquire_and_advance() self.correction_epilog( pv_thr_mma, tOtO1, scale_output / tTMEM_LOAD_VECrS[0], sO[None, None, 1], ) - mma_corr_pipeline.consumer_release(o_corr_consumer_state) - o_corr_consumer_state.advance() - - corr_epi_pipeline.producer_commit(corr_epi_producer_state) - corr_epi_producer_state.advance() - + o1_handle.release() + o1_final_handle.commit() # Advance to next tile tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() # End of persistent scheduler loop - cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) - return @cute.jit @@ -1730,33 +1651,19 @@ class BlackwellFusedMultiHeadAttentionForward: self, stage: int, need_apply_mask: bool, - seqlen_k: Int32, - row_max: Float32, - row_sum: Float32, - mma_si_consumer_state: pipeline.PipelineState, - si_corr_producer_state: pipeline.PipelineState, - s0_s1_sequence_state: pipeline.PipelineState, - mma_si_pipeline: pipeline.PipelineAsync, - si_corr_pipeline: pipeline.PipelineAsync, - s0_s1_sequence_pipeline: pipeline.PipelineAsync, - scale_softmax_log2: Float32, - cS: cute.Tensor, - qk_thr_mma: cute.core.ThrMma, - tiled_tmem_load: cute.TiledCopy, - tiled_tmem_store: cute.TiledCopy, - tiled_tmem_store_vec: cute.TiledCopy, - thr_tmem_load: cute.CopyAtom, - thr_tmem_store: cute.CopyAtom, - thr_tmem_store_vec: cute.CopyAtom, - tTMEM_LOADtS: cute.Tensor, - tTMEM_STORE_VECtS: cute.Tensor, - tTMEM_STOREtS_x4: cute.Tensor, + iter_args: tuple, + value_args: tuple, + pipeline_args: tuple, + atom_args: tuple, + tensor_args: tuple, ) -> Tuple[ Float32, Float32, - pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, + pipeline.PipelineProducer.ImmutableResourceHandle, + pipeline.PipelineConsumer, + pipeline.PipelineProducer, + pipeline.PipelineConsumer, + pipeline.PipelineProducer, ]: """Perform a single step of the softmax computation on a block of attention scores. @@ -1777,49 +1684,42 @@ class BlackwellFusedMultiHeadAttentionForward: :type stage: int :param need_apply_mask: Whether to apply attention masking :type need_apply_mask: bool - :param row_max: Current maximum value for the row - :type row_max: cute.core.Tensor - :param row_sum: Current sum value for the row - :type row_sum: cute.core.Tensor - :param mma_si_consumer_state: Pipeline state for MMA consumer operations - :type mma_si_consumer_state: pipeline.PipelineState - :param si_corr_producer_state: Pipeline state for correction producer operations - :type si_corr_producer_state: pipeline.PipelineState - :param s0_s1_sequence_state: Pipeline state for sequence synchronization - :type s0_s1_sequence_state: pipeline.PipelineState - :param mma_si_pipeline: Pipeline for MMA operations - :type mma_si_pipeline: pipeline.PipelineAsync - :param si_corr_pipeline: Pipeline for correction operations - :type si_corr_pipeline: pipeline.PipelineAsync - :param s0_s1_sequence_pipeline: Pipeline for sequence synchronization - :type s0_s1_sequence_pipeline: pipeline.PipelineAsync - :param scale_softmax_log2: Log2 scale factor for softmax computation - :type scale_softmax_log2: Float32 - :param cS: Current slice of attention matrix - :type cS: cute.Tensor - :param qk_thr_mma: Thread MMA operation - :type qk_thr_mma: cute.core.ThrMma - :param tiled_tmem_load: Tiled copy operation for loading from tensor memory - :type tiled_tmem_load: cute.TiledCopy - :param tiled_tmem_store: Tiled copy operation for storing to tensor memory - :type tiled_tmem_store: cute.TiledCopy - :param tiled_tmem_store_vec: Tiled copy operation for storing vector data - :type tiled_tmem_store_vec: cute.TiledCopy - :param thr_tmem_load: Thread copy operation for loading - :type thr_tmem_load: cute.CopyAtom - :param thr_tmem_store: Thread copy operation for storing - :type thr_tmem_store: cute.CopyAtom - :param thr_tmem_store_vec: Thread copy operation for storing vector data - :type thr_tmem_store_vec: cute.CopyAtom - :param tTMEM_LOADtS: Tensor for loading from tensor memory - :type tTMEM_LOADtS: cute.Tensor - :param tTMEM_STORE_VECtS: Tensor for storing vector data - :type tTMEM_STORE_VECtS: cute.Tensor - :param tTMEM_STOREtS_x4: Tensor for storing processed data - :type tTMEM_STOREtS_x4: cute.Tensor - :return: Updated state values (row_max, row_sum, and pipeline states) + :param iter_args: Tuple containing the counting tensor, row_max, row_sum, and vector buffer's handle for current iteration + :type iter_args: tuple + :param value_args: Tuple containing seqlen_k and scale_softmax_log2 + :type value_args: tuple + :param pipeline_args: Tuple containing pipeline related arguments for MMA, correction, and sequence synchronization + :type pipeline_args: tuple + :param atom_args: Tuple containing mma & copy atoms + :type atom_args: tuple + :param tensor_args: Tuple containing softmax related tensors + :type tensor_args: tuple + :return: Updated state values (row_max, row_sum, and pipeline related arguments) :rtype: tuple """ + cS, row_max, row_sum, vec_i_handle = iter_args + seqlen_k, scale_softmax_log2 = value_args + ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = pipeline_args + ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) = atom_args + ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) = tensor_args + tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width tScS = qk_thr_mma.partition_C(cS) tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) @@ -1834,10 +1734,9 @@ class BlackwellFusedMultiHeadAttentionForward: tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P) # Wait for Si - mma_si_pipeline.consumer_wait(mma_si_consumer_state) + si_handle = mma_si_consumer.wait_and_advance() tTMEM_LOADrS = cute.make_fragment(tTMEM_LOADcS.shape, self.qk_acc_dtype) cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) - if need_apply_mask: self.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, seqlen_k) @@ -1846,7 +1745,6 @@ class BlackwellFusedMultiHeadAttentionForward: row_max_safe = row_max if row_max == -cutlass.Float32.inf: row_max_safe = 0.0 - tTMEM_STORE_VECrS = cute.make_fragment( tTMEM_STORE_VECcS.shape, self.qk_acc_dtype ) @@ -1855,8 +1753,7 @@ class BlackwellFusedMultiHeadAttentionForward: cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) cute.arch.fence_view_async_tmem_store() # Notify correction wg that row_max is ready - si_corr_pipeline.producer_commit(si_corr_producer_state) - si_corr_producer_state.advance() + vec_i_handle.commit() tTMEM_STORErS_x4 = cute.make_fragment(tTMEM_STOREcS.shape, self.qk_acc_dtype) tTMEM_STORErS_x4_e = cute.make_tensor( @@ -1868,11 +1765,10 @@ class BlackwellFusedMultiHeadAttentionForward: minus_row_max_scale = (0.0 - row_max_safe) * scale # Sequence barrier wait - if stage == 0: - s0_s1_sequence_pipeline.producer_acquire(s0_s1_sequence_state) + if cutlass.const_expr(stage == 0): + sequence_producer_handle = s0_s1_sequence_producer.acquire_and_advance() else: - s0_s1_sequence_pipeline.consumer_wait(s0_s1_sequence_state) - + sequence_consumer_handle = s0_s1_sequence_consumer.wait_and_advance() frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) @@ -1888,29 +1784,27 @@ class BlackwellFusedMultiHeadAttentionForward: (minus_row_max_scale, minus_row_max_scale), ) ) - tTMEM_LOADrS_frg[k, j] = cute.arch.exp2(tTMEM_LOADrS_frg[k, j]) - tTMEM_LOADrS_frg[k + 1, j] = cute.arch.exp2(tTMEM_LOADrS_frg[k + 1, j]) + tTMEM_LOADrS_frg[k, j] = cute.math.exp2( + tTMEM_LOADrS_frg[k, j], fastmath=True + ) + tTMEM_LOADrS_frg[k + 1, j] = cute.math.exp2( + tTMEM_LOADrS_frg[k + 1, j], fastmath=True + ) s_vec = tTMEM_LOADrS_frg[None, j].load() tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) - # Sequence barrier arrive - if stage == 0: - s0_s1_sequence_pipeline.producer_commit(s0_s1_sequence_state) + if cutlass.const_expr(stage == 0): + sequence_producer_handle.commit() else: - s0_s1_sequence_pipeline.consumer_release(s0_s1_sequence_state) - s0_s1_sequence_state.advance() - + sequence_consumer_handle.release() cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4) cute.arch.fence_view_async_tmem_store() + # Notify tensor core warp that softmax(S->P) is ready + si_handle.release() - # Notify tensor core warp that P is ready - mma_si_pipeline.consumer_release(mma_si_consumer_state) - mma_si_consumer_state.advance() - - si_corr_pipeline.producer_acquire(si_corr_producer_state) - + vec_i_handle = si_corr_producer.acquire_and_advance() acc_scale_ = scale * (old_row_max - row_max_safe) - acc_scale = cute.arch.exp2(acc_scale_) * 0.5 + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) * 0.5 row_sum *= acc_scale local_row_sum_0 = (row_sum, row_sum) local_row_sum_1 = (0.0, 0.0) @@ -1943,12 +1837,14 @@ class BlackwellFusedMultiHeadAttentionForward: return ( row_max, row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, ) - # for both softmax0 and softmax1 warp group + # For both softmax0 and softmax1 warp group @cute.jit def softmax( self, @@ -1960,9 +1856,10 @@ class BlackwellFusedMultiHeadAttentionForward: qk_thr_mma: cute.core.ThrMma, tStS: cute.Tensor, tStSi: cute.Tensor, - mma_si_pipeline: pipeline.PipelineAsync, - si_corr_pipeline: pipeline.PipelineAsync, - s0_s1_sequence_pipeline: pipeline.PipelineAsync, + mma_si_consumer: pipeline.PipelineConsumer, + si_corr_producer: pipeline.PipelineProducer, + s0_s1_sequence_consumer: pipeline.PipelineConsumer, + s0_s1_sequence_producer: pipeline.PipelineProducer, tile_sched_params: FmhaStaticTileSchedulerParams, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -2008,29 +1905,22 @@ class BlackwellFusedMultiHeadAttentionForward: cS_base = cute.make_identity_tensor( (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) ) - tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width - tScS = qk_thr_mma.partition_C(cS_base) - tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) tmem_vec_offset = self.tmem_vec0_offset if stage == 0 else self.tmem_vec1_offset tStS_vec = cute.make_tensor(tStS.iterator + tmem_vec_offset, tStS_vec_layout) - tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) - tStS_P_layout = cute.composition( tStS.layout, cute.make_layout((128, tilePlikeFP32)) ) tmem_p_offset = self.tmem_p0_offset if stage == 0 else self.tmem_p1_offset tStS_P = cute.make_tensor(tStS.iterator + tmem_p_offset, tStS_P_layout) - tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype, ) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi) thread_idx = tidx % ( self.threads_per_warp @@ -2042,14 +1932,12 @@ class BlackwellFusedMultiHeadAttentionForward: ) thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) tTMEM_LOADtS = thr_tmem_load.partition_S(tStSi) - tmem_store_vec_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype, ) tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec) thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(thread_idx) - tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec) tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) tmem_store_atom = cute.make_copy_atom( @@ -2060,21 +1948,6 @@ class BlackwellFusedMultiHeadAttentionForward: thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) - mma_si_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_softmax_stage - ) - si_corr_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.softmax_corr_stage - ) - s0_s1_sequence_state = pipeline.make_pipeline_state( - ( - pipeline.PipelineUserType.Producer - if stage == 0 - else pipeline.PipelineUserType.Consumer - ), - 1, - ) - tile_sched = create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) @@ -2101,59 +1974,62 @@ class BlackwellFusedMultiHeadAttentionForward: if cutlass.const_expr(cum_seqlen_k is not None): cuseqlen_k = cum_seqlen_k[batch_coord] seqlen_k_ = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + row_max = -Float32.inf + row_sum = 0.0 + value_args = (seqlen_k_, scale_softmax_log2) + atom_args = ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) + tensor_args = ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) logical_offset = ( curr_block_coord[0] * self.cta_tiler[0] + stage * self.qk_mma_tiler[0], 0, ) - cS = cute.domain_offset(logical_offset, cS_base) - si_corr_pipeline.producer_acquire(si_corr_producer_state) - + vec_i_handle = si_corr_producer.acquire_and_advance() unmask_count = self.get_unmasked_trip_count( curr_block_coord, self.cta_tiler, seqlen_k_, ) - - row_max = -Float32.inf - row_sum = 0.0 - for i in cutlass.range(0, unmask_count, 1, unroll=1): cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = (cS_iter, row_max, row_sum, vec_i_handle) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) ( row_max, row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, ) = self.softmax_step( stage, False, - seqlen_k_, - row_max, - row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, - mma_si_pipeline, - si_corr_pipeline, - s0_s1_sequence_pipeline, - scale_softmax_log2, - cS_iter, - qk_thr_mma, - tiled_tmem_load, - tiled_tmem_store, - tiled_tmem_store_vec, - thr_tmem_load, - thr_tmem_store, - thr_tmem_store_vec, - tTMEM_LOADtS, - tTMEM_STORE_VECtS, - tTMEM_STOREtS_x4, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, ) - mask_count = self.get_masked_trip_count( curr_block_coord, self.cta_tiler, @@ -2164,40 +2040,31 @@ class BlackwellFusedMultiHeadAttentionForward: unmask_count, unmask_count + mask_count, 1, unroll=1 ): cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = (cS_iter, row_max, row_sum, vec_i_handle) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) ( row_max, row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, ) = self.softmax_step( stage, True, - seqlen_k_, - row_max, - row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, - mma_si_pipeline, - si_corr_pipeline, - s0_s1_sequence_pipeline, - scale_softmax_log2, - cS_iter, - qk_thr_mma, - tiled_tmem_load, - tiled_tmem_store, - tiled_tmem_store_vec, - thr_tmem_load, - thr_tmem_store, - thr_tmem_store_vec, - tTMEM_LOADtS, - tTMEM_STORE_VECtS, - tTMEM_STOREtS_x4, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, ) - - mma_si_pipeline.consumer_wait(mma_si_consumer_state) - + si_handle = mma_si_consumer.wait_and_advance() tTMEM_STORE_VECrS = cute.make_fragment( tTMEM_STORE_VECcS.shape, self.qk_acc_dtype ) @@ -2205,15 +2072,10 @@ class BlackwellFusedMultiHeadAttentionForward: tTMEM_STORE_VECrS[1] = row_max cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) cute.arch.fence_view_async_tmem_store() - - si_corr_pipeline.producer_commit(si_corr_producer_state) - si_corr_producer_state.advance() - - si_corr_pipeline.producer_acquire(si_corr_producer_state) - + vec_i_handle.commit() + si_corr_producer.acquire() # Empty step to sync against pipe s - mma_si_pipeline.consumer_release(mma_si_consumer_state) - mma_si_consumer_state.advance() + si_handle.release() # Advance to next tile tile_sched.advance_to_next_work() @@ -2439,7 +2301,11 @@ class BlackwellFusedMultiHeadAttentionForward: else: result = 0 elif self.mask_type == MaskType.CAUSAL_MASK: - result = cute.ceil_div(tile_shape[0], tile_shape[1]) + trip_count = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + result = cutlass.min( + trip_count, + cute.ceil_div(tile_shape[0], tile_shape[1]), + ) return result @cute.jit @@ -2481,122 +2347,6 @@ class BlackwellFusedMultiHeadAttentionForward: if pos[0] < pos[1] or pos[1] >= seqlen_k: acc_qk[i] = -Float32.inf - def make_and_init_load_q_pipeline(self, load_q_mbar_ptr): - load_q_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.load_warp_id]) - ) - load_q_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - return pipeline.PipelineTmaUmma.create( - num_stages=self.q_stage, - producer_group=load_q_producer_group, - consumer_group=load_q_consumer_group, - tx_count=self.tma_copy_q_bytes, - barrier_storage=load_q_mbar_ptr, - ) - - def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): - load_kv_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.load_warp_id]) - ) - load_kv_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - return pipeline.PipelineTmaUmma.create( - num_stages=self.kv_stage, - producer_group=load_kv_producer_group, - consumer_group=load_kv_consumer_group, - tx_count=self.tma_copy_kv_bytes, - barrier_storage=load_kv_mbar_ptr, - ) - - def make_and_init_mma_si_pipeline(self, mma_si_mbar_ptr): - mma_si_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - mma_si_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len(self.softmax0_warp_ids), - self.threads_per_warp * len(self.softmax0_warp_ids), - ) - return pipeline.PipelineUmmaAsync.create( - num_stages=self.mma_softmax_stage, - producer_group=mma_si_producer_group, - consumer_group=mma_si_consumer_group, - barrier_storage=mma_si_mbar_ptr, - ) - - def make_and_init_si_corr_pipeline(self, si_corr_mbar_ptr): - si_corr_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len(self.softmax0_warp_ids), - self.threads_per_warp * len(self.softmax0_warp_ids), - ) - si_corr_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len(self.correction_warp_ids), - self.threads_per_warp * len(self.correction_warp_ids), - ) - return pipeline.PipelineAsync.create( - num_stages=self.softmax_corr_stage, - producer_group=si_corr_producer_group, - consumer_group=si_corr_consumer_group, - barrier_storage=si_corr_mbar_ptr, - ) - - def make_and_init_corr_epi_pipeline(self, corr_epi_mbar_ptr): - corr_epi_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len(self.correction_warp_ids), - self.threads_per_warp * len(self.correction_warp_ids), - ) - corr_epi_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len([self.epilogue_warp_id]), - self.threads_per_warp * len([self.epilogue_warp_id]), - ) - return pipeline.PipelineAsync.create( - num_stages=self.epi_stage, - producer_group=corr_epi_producer_group, - consumer_group=corr_epi_consumer_group, - barrier_storage=corr_epi_mbar_ptr, - ) - - def make_and_init_mma_corr_pipeline(self, mma_corr_mbar_ptr): - mma_corr_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - mma_corr_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len(self.correction_warp_ids), - self.threads_per_warp * len(self.correction_warp_ids), - ) - return pipeline.PipelineUmmaAsync.create( - num_stages=self.mma_corr_stage, - producer_group=mma_corr_producer_group, - consumer_group=mma_corr_consumer_group, - barrier_storage=mma_corr_mbar_ptr, - ) - - def make_and_init_si_sequence_pipeline(self, si_sequence_mbar_ptr): - s0_sequence_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len(self.softmax0_warp_ids), - self.threads_per_warp * len(self.softmax0_warp_ids), - ) - s1_sequence_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len(self.softmax1_warp_ids), - self.threads_per_warp * len(self.softmax1_warp_ids), - ) - return pipeline.PipelineAsync.create( - num_stages=1, - producer_group=s0_sequence_group, - consumer_group=s1_sequence_group, - barrier_storage=si_sequence_mbar_ptr, - ) - @staticmethod def _compute_grid( o_shape: cute.Shape, @@ -2612,7 +2362,6 @@ class BlackwellFusedMultiHeadAttentionForward: ), ) grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) - return tile_sched_params, grid diff --git a/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_gemm.py index 0dba2bb6..c279f758 100644 --- a/examples/python/CuTeDSL/blackwell/grouped_gemm.py +++ b/examples/python/CuTeDSL/blackwell/grouped_gemm.py @@ -475,7 +475,6 @@ class GroupedGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, ) return @@ -785,7 +784,7 @@ class GroupedGemmKernel: ) tensormap_init_done = cutlass.Boolean(False) # tile count we have searched - total_k_block_cnt = cutlass.Int32(0) + total_k_tile_cnt = cutlass.Int32(0) # group index of last tile last_group_idx = cutlass.Int32(-1) work_tile = tile_sched.initial_work_tile_info() @@ -795,7 +794,7 @@ class GroupedGemmKernel: cur_tile_coord, problem_sizes_mnkl, ) - cur_k_block_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k cur_group_idx = grouped_gemm_cta_tile_info.group_idx is_group_changed = cur_group_idx != last_group_idx # skip tensormap update if we're working on the same group @@ -861,17 +860,17 @@ class GroupedGemmKernel: (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] - num_prev_k_blk = total_k_block_cnt - total_k_block_cnt += cur_k_block_cnt + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt - tma_wr_k_block = cutlass.Int32(0) - smem_wr_buffer = (num_prev_k_blk + tma_wr_k_block) % self.num_ab_stage + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + tma_wr_k_tile = cutlass.Int32(0) + smem_wr_buffer = (num_prev_k_blk + tma_wr_k_tile) % self.num_ab_stage tma_wr_ab_empty_phase = ( - num_prev_k_blk + tma_wr_k_block + num_prev_k_blk + tma_wr_k_tile ) // self.num_ab_stage % 2 ^ 1 peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( - tma_wr_k_block < cur_k_block_cnt, + tma_wr_k_tile < cur_k_tile_cnt, ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase, ) @@ -882,10 +881,10 @@ class GroupedGemmKernel: # # Tma load loop # - for k_block in cutlass.range(0, cur_k_block_cnt, 1, unroll=1): - tma_wr_k_block_next = tma_wr_k_block + 1 + for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): + tma_wr_k_tile_next = tma_wr_k_tile + 1 smem_wr_buffer_next = ( - num_prev_k_blk + tma_wr_k_block_next + num_prev_k_blk + tma_wr_k_tile_next ) % self.num_ab_stage tma_wr_ab_empty_phase_next = ( tma_wr_ab_empty_phase ^ 1 @@ -911,7 +910,7 @@ class GroupedGemmKernel: # Load A/B with TMA cute.copy( tma_atom_a, - tAgA_slice[(None, tma_wr_k_block)], + tAgA_slice[(None, tma_wr_k_tile)], tAsA[(None, smem_wr_buffer)], tma_bar_ptr=smem_full_mbar_ptr, mcast_mask=a_full_mcast_mask, @@ -922,7 +921,7 @@ class GroupedGemmKernel: ) cute.copy( tma_atom_b, - tBgB_slice[(None, tma_wr_k_block)], + tBgB_slice[(None, tma_wr_k_tile)], tBsB[(None, smem_wr_buffer)], tma_bar_ptr=smem_full_mbar_ptr, mcast_mask=b_full_mcast_mask, @@ -932,14 +931,14 @@ class GroupedGemmKernel: ), ) - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( - tma_wr_k_block_next < cur_k_block_cnt, + tma_wr_k_tile_next < cur_k_tile_cnt, ab_empty_mbar_ptr + smem_wr_buffer_next, tma_wr_ab_empty_phase_next, ) - tma_wr_k_block = tma_wr_k_block_next + tma_wr_k_tile = tma_wr_k_tile_next smem_wr_buffer = smem_wr_buffer_next tma_wr_ab_empty_phase = tma_wr_ab_empty_phase_next @@ -998,12 +997,12 @@ class GroupedGemmKernel: work_tile = tile_sched.initial_work_tile_info() # tile count we have searched - total_k_block_cnt = cutlass.Int32(0) + total_k_tile_cnt = cutlass.Int32(0) while work_tile.is_valid_tile: cur_tile_coord = work_tile.tile_idx # MMA warp is only interested in number of tiles along K dimension ( - cur_k_block_cnt, + cur_k_tile_cnt, cur_group_idx, ) = group_gemm_ts_helper.search_cluster_tile_count_k( cur_tile_coord, @@ -1014,17 +1013,17 @@ class GroupedGemmKernel: # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_buf_idx)] - num_prev_k_blk = total_k_block_cnt - total_k_block_cnt += cur_k_block_cnt + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt - # Peek (try_wait) AB buffer full for k_block = 0 - mma_rd_k_block = cutlass.Int32(0) - smem_rd_buffer = (num_prev_k_blk + mma_rd_k_block) % self.num_ab_stage + # Peek (try_wait) AB buffer full for k_tile = 0 + mma_rd_k_tile = cutlass.Int32(0) + smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage need_check_rd_buffer_full = ( - mma_rd_k_block < cur_k_block_cnt and is_leader_cta + mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta ) mma_rd_ab_full_phase = ( - (num_prev_k_blk + mma_rd_k_block) // self.num_ab_stage % 2 + (num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2 ) peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( need_check_rd_buffer_full, @@ -1051,10 +1050,10 @@ class GroupedGemmKernel: # # Mma mainloop # - for k_block in range(cur_k_block_cnt): - mma_rd_k_block_next = cutlass.Int32(k_block + 1) + for k_tile in range(cur_k_tile_cnt): + mma_rd_k_tile_next = cutlass.Int32(k_tile + 1) smem_rd_buffer_next = ( - num_prev_k_blk + mma_rd_k_block_next + num_prev_k_blk + mma_rd_k_tile_next ) % self.num_ab_stage mma_rd_ab_full_phase_next = ( mma_rd_ab_full_phase ^ 1 @@ -1069,18 +1068,18 @@ class GroupedGemmKernel: ) # tCtAcc += tCrA * tCrB - num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord = (None, None, kphase_idx, smem_rd_buffer) + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, smem_rd_buffer) cute.gemm( tiled_mma, tCtAcc, - tCrA[kphase_coord], - tCrB[kphase_coord], + tCrA[kblock_coord], + tCrB[kblock_coord], tCtAcc, ) - # Enable accumulate on tCtAcc after first kphase + # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty @@ -1091,9 +1090,9 @@ class GroupedGemmKernel: self.cta_group, ) - # Peek (try_wait) AB buffer full for k_block = k_block + 1 + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 need_check_rd_buffer_full = ( - mma_rd_k_block_next < cur_k_block_cnt and is_leader_cta + mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta ) peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( @@ -1102,7 +1101,7 @@ class GroupedGemmKernel: mma_rd_ab_full_phase_next, ) - mma_rd_k_block = mma_rd_k_block_next + mma_rd_k_tile = mma_rd_k_tile_next smem_rd_buffer = smem_rd_buffer_next mma_rd_ab_full_phase = mma_rd_ab_full_phase_next @@ -1201,7 +1200,7 @@ class GroupedGemmKernel: # wait tensormap initialization complete before update tensormap_manager.fence_tensormap_initialization() # tile count we have searched - total_k_block_cnt = cutlass.Int32(0) + total_k_tile_cnt = cutlass.Int32(0) # group index of last tile last_group_idx = cutlass.Int32(-1) while work_tile.is_valid_tile: @@ -1240,8 +1239,8 @@ class GroupedGemmKernel: grouped_gemm_cta_tile_info.cta_tile_idx_n, 0, ) - cur_k_block_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k - total_k_block_cnt += cur_k_block_cnt + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + total_k_tile_cnt += cur_k_tile_cnt # # Slice to per mma tile index @@ -1370,8 +1369,8 @@ class GroupedGemmKernel: # if warp_idx == self.epilog_warp_id[0]: cute.arch.mbarrier_wait( - (ab_empty_mbar_ptr + ((total_k_block_cnt - 1) % self.num_ab_stage)), - (((total_k_block_cnt - 1) // self.num_ab_stage) % 2), + (ab_empty_mbar_ptr + ((total_k_tile_cnt - 1) % self.num_ab_stage)), + (((total_k_tile_cnt - 1) // self.num_ab_stage) % 2), ) @cute.jit diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py index e77221bb..829d6b7e 100644 --- a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py @@ -622,7 +622,6 @@ class SSDKernel: block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, min_blocks_per_mp=1, - smem=self.shared_storage.size_in_bytes(), stream=stream, ) @@ -693,7 +692,7 @@ class SSDKernel: G = cute.size(tma_tensor_b, mode=[3]) NGROUP_RATIO = EH // G - # Make tiledMma + # Make TiledMma ( tiled_mma_intra1, tiled_mma_intra2, @@ -1745,7 +1744,7 @@ class SSDKernel: cute.arch.fence_view_async_tmem_load() # Combine INTER1_ACC/last_column/State - exp_last_column = cute.arch.exp(last_column.ir_value()) + exp_last_column = cute.math.exp(last_column, fastmath=True) for reg_idx in range(0, cute.size(tTR_rP), 2): ( tTR_rP[reg_idx], @@ -2267,9 +2266,11 @@ class SSDKernel: ) = cute.arch.fma_packed_f32x2( (tTR_rInter[reg_idx], tTR_rInter[reg_idx + 1]), ( - cute.arch.exp(tTR_rDeltaA[reg_idx].ir_value()), - cute.arch.exp( - tTR_rDeltaA[reg_idx + 1].ir_value() + cute.math.exp( + tTR_rDeltaA[reg_idx], fastmath=True + ), + cute.math.exp( + tTR_rDeltaA[reg_idx + 1], fastmath=True ), ), (tTR_rIntra[reg_idx], tTR_rIntra[reg_idx + 1]), @@ -3072,14 +3073,19 @@ class SSDKernel: m, n = tCoord[subtile_idx] if m < n: tCompute[subtile_idx] = cutlass.Float32(-float("inf")) + LOG2_E = cutlass.Float32(1.4426950408889634) for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True): # TODO: use math.exp directly + tCompute_log2e = cute.arch.mul_packed_f32x2( + (tCompute[subtile_idx], tCompute[subtile_idx + 1]), (LOG2_E, LOG2_E) + ) ( tCompute[subtile_idx], tCompute[subtile_idx + 1], ) = cute.arch.mul_packed_f32x2( - cute.arch.exp_packed_f32x2( - (tCompute[subtile_idx], tCompute[subtile_idx + 1]) + ( + cute.math.exp2(tCompute_log2e[0], fastmath=True), + cute.math.exp2(tCompute_log2e[1], fastmath=True), ), (tCrDelta[subtile_idx], tCrDelta[subtile_idx + 1]), ) @@ -3245,11 +3251,11 @@ class SSDKernel: for reg_idx in range(0, cute.size(tBrB_Compute), 2): tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2( ( - cute.arch.exp( - (last_column - tBrDeltaA_Compute[reg_idx]).ir_value() + cute.math.exp( + (last_column - tBrDeltaA_Compute[reg_idx]), fastmath=True ), - cute.arch.exp( - (last_column - tBrDeltaA_Compute[reg_idx + 1]).ir_value() + cute.math.exp( + (last_column - tBrDeltaA_Compute[reg_idx + 1]), fastmath=True ), ), (tBrDelta_Compute[reg_idx], tBrDelta_Compute[reg_idx + 1]), diff --git a/examples/python/CuTeDSL/hopper/dense_gemm.py b/examples/python/CuTeDSL/hopper/dense_gemm.py index 6bab06ea..c59ace02 100644 --- a/examples/python/CuTeDSL/hopper/dense_gemm.py +++ b/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -44,7 +44,7 @@ import cutlass.utils.hopper_helpers as sm90_utils """ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture -using CUTE DSL. +using CuTe DSL. - Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") - Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") - Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") @@ -70,7 +70,7 @@ To run this example: .. code-block:: bash python examples/hopper/dense_gemm.py \ - --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ --c_dtype Float16 --acc_dtype Float32 \ --a_major k --b_major k --c_major n @@ -85,7 +85,7 @@ To collect performance with NCU profiler: .. code-block:: bash ncu python examples/hopper/dense_gemm.py \ - --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ --c_dtype Float16 --acc_dtype Float32 \ --a_major k --b_major k --c_major n @@ -95,14 +95,11 @@ Constraints: * For fp16 types, A and B must have the same data type * For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit * Fp8 types only support k-major layout -* Only fp32 accumulation is supported in this example * CTA tile shape M must be 64/128 * CTA tile shape N must be 64/128/256 -* CTA tile shape K must be 64 * Cluster shape M/N must be positive and power of 2, total cluster size <= 4 * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively. -* OOB tiles are not allowed when TMA store is disabled """ @@ -128,10 +125,10 @@ def parse_arguments() -> argparse.Namespace: help="mnkl dimensions (comma-separated)", ) parser.add_argument( - "--tile_shape_mnk", + "--tile_shape_mn", type=parse_comma_separated_ints, - choices=[(128, 128, 64), (128, 256, 64), (128, 64, 64), (64, 64, 64)], - default=(128, 128, 64), + choices=[(128, 128), (128, 256), (128, 64), (64, 64)], + default=(128, 128), help="Cta tile shape (comma-separated)", ) parser.add_argument( @@ -190,8 +187,8 @@ def parse_arguments() -> argparse.Namespace: if len(args.mnkl) != 4: parser.error("--mnkl must contain exactly 4 values") - if len(args.tile_shape_mnk) != 3: - parser.error("--tile_shape_mnk must contain exactly 3 values") + if len(args.tile_shape_mn) != 2: + parser.error("--tile_shape_mn must contain exactly 2 values") if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") @@ -210,10 +207,10 @@ class HopperWgmmaGemmKernel: :param acc_dtype: Data type for accumulation during computation :type acc_dtype: type[cutlass.Numeric] - :param tile_shape_mnk: Shape of the CTA tile (M,N,K) - :type tile_shape_mnk: Tuple[int, int, int] - :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing - :type cluster_shape_mnk: Tuple[int, int, int] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] :note: Data type requirements: - For 16-bit types: A and B must have the same data type @@ -236,8 +233,8 @@ class HopperWgmmaGemmKernel: Example: >>> gemm = HopperWgmmaGemmKernel( ... acc_dtype=cutlass.Float32, - ... tile_shape_mnk=(128, 256, 64), - ... cluster_shape_mnk=(1, 1, 1) + ... tile_shape_mn=(128, 256), + ... cluster_shape_mn=(1, 1) ... ) >>> gemm(a_tensor, b_tensor, c_tensor, stream) """ @@ -245,8 +242,8 @@ class HopperWgmmaGemmKernel: def __init__( self, acc_dtype: type[cutlass.Numeric], - tile_shape_mnk: tuple[int, int, int], - cluster_shape_mnk: tuple[int, int, int], + tile_shape_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], ): """ Initializes the configuration for a Hopper dense GEMM kernel. @@ -256,28 +253,30 @@ class HopperWgmmaGemmKernel: :param acc_dtype: Data type for accumulation during computation :type acc_dtype: type[cutlass.Numeric] - :param tile_shape_mnk: Shape of the CTA tile (M,N,K) - :type tile_shape_mnk: Tuple[int, int, int] - :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing - :type cluster_shape_mnk: Tuple[int, int, int] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] """ self.acc_dtype = acc_dtype - self.cluster_shape_mnk = cluster_shape_mnk + self.cluster_shape_mn = cluster_shape_mn self.mma_inst_shape_mn = None - self.tile_shape_mnk = tuple(tile_shape_mnk) + # K dimension is deferred in _setup_attributes + self.tile_shape_mnk = (*tile_shape_mn, 1) # For large tile size, using two warp groups is preferred because using only one warp # group may result in register spill self.atom_layout_mnk = ( (2, 1, 1) - if tile_shape_mnk[0] > 64 and tile_shape_mnk[1] > 128 + if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128 else (1, 1, 1) ) self.num_mcast_ctas_a = None self.num_mcast_ctas_b = None self.is_a_mcast = False self.is_b_mcast = False + self.tiled_mma = None self.occupancy = 1 self.mma_warp_groups = math.prod(self.atom_layout_mnk) @@ -315,12 +314,27 @@ class HopperWgmmaGemmKernel: raise ValueError("CTA tile shape M must be 64/128") if self.tile_shape_mnk[1] not in [64, 128, 256]: raise ValueError("CTA tile shape N must be 64/128/256") - if self.tile_shape_mnk[2] not in [64]: - raise ValueError("CTA tile shape K must be 64") - self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk) - self.num_mcast_ctas_a = self.cluster_shape_mnk[1] - self.num_mcast_ctas_b = self.cluster_shape_mnk[0] + self.tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.tile_shape_mnk[1]), + ) + mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.tile_shape_mnk = ( + self.tile_shape_mnk[0], + self.tile_shape_mnk[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1)) + self.num_mcast_ctas_a = self.cluster_shape_mn[1] + self.num_mcast_ctas_b = self.cluster_shape_mn[0] self.is_a_mcast = self.num_mcast_ctas_a > 1 self.is_b_mcast = self.num_mcast_ctas_b > 1 @@ -401,28 +415,18 @@ class HopperWgmmaGemmKernel: self._setup_attributes() - tiled_mma = sm90_utils.make_trivial_tiled_mma( - self.a_dtype, - self.b_dtype, - self.a_layout.sm90_mma_major_mode(), - self.b_layout.sm90_mma_major_mode(), - self.acc_dtype, - self.atom_layout_mnk, - tiler_mn=(64, self.tile_shape_mnk[1]), - ) - tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( a, self.a_smem_layout_staged, (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), - self.cluster_shape_mnk[1], + self.cluster_shape_mn[1], ) tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( b, self.b_smem_layout_staged, (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), - self.cluster_shape_mnk[0], + self.cluster_shape_mn[0], ) tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors( @@ -431,20 +435,20 @@ class HopperWgmmaGemmKernel: self.epi_tile, ) - grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mnk) + grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mn) @cute.struct class SharedStorage: mainloop_pipeline_array_ptr: cute.struct.MemRange[ cutlass.Int64, self.ab_stage * 2 ] - sa: cute.struct.Align[ + sA: cute.struct.Align[ cute.struct.MemRange[ self.a_dtype, cute.cosize(self.a_smem_layout_staged) ], self.buffer_align_bytes, ] - sb: cute.struct.Align[ + sB: cute.struct.Align[ cute.struct.MemRange[ self.b_dtype, cute.cosize(self.b_smem_layout_staged) ], @@ -461,7 +465,7 @@ class HopperWgmmaGemmKernel: tma_tensor_b, tma_atom_c, tma_tensor_c, - tiled_mma, + self.tiled_mma, self.cta_layout_mnk, self.a_smem_layout_staged, self.b_smem_layout_staged, @@ -469,8 +473,7 @@ class HopperWgmmaGemmKernel: ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], - cluster=self.cluster_shape_mnk, - smem=self.shared_storage.size_in_bytes(), + cluster=(*self.cluster_shape_mn, 1), stream=stream, ) return @@ -562,8 +565,8 @@ class HopperWgmmaGemmKernel: # Get the pid from cluster id bidx_in_cluster = cute.arch.block_in_cluster_idx() - pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0] - pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1] + pid_m = cid_m * self.cluster_shape_mn[0] + bidx_in_cluster[0] + pid_n = cid_n * self.cluster_shape_mn[1] + bidx_in_cluster[1] tile_coord_mnkl = (pid_m, pid_n, None, bidz) cta_rank_in_cluster = cute.arch.make_warp_uniform( @@ -621,22 +624,22 @@ class HopperWgmmaGemmKernel: ) # Cluster arrive after barrier init - if cute.size(self.cluster_shape_mnk) > 1: + if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_arrive_relaxed() # /////////////////////////////////////////////////////////////////////////////// # Generate smem tensor A/B # /////////////////////////////////////////////////////////////////////////////// - sa = storage.sa.get_tensor( + sA = storage.sA.get_tensor( a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner ) - sb = storage.sb.get_tensor( + sB = storage.sB.get_tensor( b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner ) - sc_ptr = cute.recast_ptr( - sa.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype + sC_ptr = cute.recast_ptr( + sA.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype ) - sc = cute.make_tensor(sc_ptr, epi_smem_layout_staged.outer) + sC = cute.make_tensor(sC_ptr, epi_smem_layout_staged.outer) # /////////////////////////////////////////////////////////////////////////////// # Local_tile partition global tensors @@ -673,34 +676,34 @@ class HopperWgmmaGemmKernel: # TMA load A partition_S/D a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) a_cta_crd = cluster_coord_mnk[1] - sa_for_tma_partition = cute.group_modes(sa, 0, 2) + sA_for_tma_partition = cute.group_modes(sA, 0, 2) gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2) tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition( tma_atom_a, a_cta_crd, a_cta_layout, - sa_for_tma_partition, + sA_for_tma_partition, gA_for_tma_partition, ) # TMA load B partition_S/D b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) b_cta_crd = cluster_coord_mnk[0] - sb_for_tma_partition = cute.group_modes(sb, 0, 2) + sB_for_tma_partition = cute.group_modes(sB, 0, 2) gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2) tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition( tma_atom_b, b_cta_crd, b_cta_layout, - sb_for_tma_partition, + sB_for_tma_partition, gB_for_tma_partition, ) # ////////////////////////////////////////////////////////////////////////////// - # Make frangments + # Make fragments # ////////////////////////////////////////////////////////////////////////////// - tCsA = thr_mma.partition_A(sa) - tCsB = thr_mma.partition_B(sb) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) tCrA = tiled_mma.make_fragment_A(tCsA) tCrB = tiled_mma.make_fragment_B(tCsB) @@ -711,7 +714,7 @@ class HopperWgmmaGemmKernel: # Cluster wait # /////////////////////////////////////////////////////////////////////////////// # cluster wait for barrier init - if cute.size(self.cluster_shape_mnk) > 1: + if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() else: cute.arch.sync_threads() @@ -788,7 +791,7 @@ class HopperWgmmaGemmKernel: tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) num_k_blocks = cute.size(tCrA, mode=[2]) - for k_tile in range(k_pipe_mmas): + for k_tile in cutlass.range_constexpr(k_pipe_mmas): # Wait for A/B buffer to be ready mainloop_pipeline.consumer_wait( mainloop_consumer_read_state, peek_ab_full_status @@ -917,7 +920,7 @@ class HopperWgmmaGemmKernel: # ///////////////////////////////////////////////////////////////////////////// cute.nvgpu.warpgroup.wait_group(0) - if cute.size(self.cluster_shape_mnk) > 1: + if cute.size(self.cluster_shape_mn) > 1: # Wait for all threads in the cluster to finish, avoid early release of smem cute.arch.cluster_arrive() cute.arch.cluster_wait() @@ -950,33 +953,45 @@ class HopperWgmmaGemmKernel: # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) - tRS_sD = thr_copy_r2s.partition_D(sc) + tRS_sD = thr_copy_r2s.partition_D(sC) # (R2S, R2S_M, R2S_N) tRS_rAcc = tiled_copy_r2s.retile(accumulators) # Allocate D registers. - rD_shape = cute.shape(thr_copy_r2s.partition_S(sc)) + rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) tRS_rD_layout = cute.make_layout(rD_shape[:3]) tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype) size_tRS_rD = cute.size(tRS_rD) - sepi_for_tma_partition = cute.group_modes(sc, 0, 2) - tcgc_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile) + sepi_for_tma_partition = cute.group_modes(sC, 0, 2) + tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile) bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( tma_atom_c, 0, cute.make_layout(1), sepi_for_tma_partition, - tcgc_for_tma_partition, + tCgC_for_tma_partition, ) - epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1]) - epi_tile_shape = tcgc_for_tma_partition.shape[1] + epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1]) + epi_tile_shape = tCgC_for_tma_partition.shape[1] + epi_tile_layout = cute.make_layout( + epi_tile_shape, stride=(epi_tile_shape[1], 1) + ) - for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num): + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=c_producer_group, + ) + + for epi_idx in cutlass.range_constexpr(epi_tile_num): # Copy from accumulators to D registers - for epi_v in range(size_tRS_rD): + for epi_v in cutlass.range_constexpr(size_tRS_rD): tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] # Type conversion @@ -997,10 +1012,6 @@ class HopperWgmmaGemmKernel: # barrier for sync cute.arch.barrier() - # Get the global memory coordinate for the current epi tile. - epi_tile_layout = cute.make_layout( - epi_tile_shape, stride=(epi_tile_shape[1], 1) - ) gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) # Copy from shared memory to global memory if warp_idx == 0: @@ -1009,11 +1020,14 @@ class HopperWgmmaGemmKernel: bSG_sD[(None, epi_buffer)], bSG_gD[(None, gmem_coord)], ) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() cute.arch.barrier() + if warp_idx == 0: + c_pipeline.producer_tail() + return @staticmethod @@ -1055,9 +1069,7 @@ class HopperWgmmaGemmKernel: mbar_helpers_bytes = 1024 ab_stage = ( - (smem_capacity - occupancy * 1024) // occupancy - - mbar_helpers_bytes - - epi_bytes + smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes ) // ab_bytes_per_stage return ab_stage, epi_stage @@ -1195,7 +1207,7 @@ class HopperWgmmaGemmKernel: def _compute_grid( c: cute.Tensor, tile_shape_mnk: tuple[int, int, int], - cluster_shape_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], ) -> tuple[int, int, int]: """Compute grid shape for the output tensor C. @@ -1203,8 +1215,8 @@ class HopperWgmmaGemmKernel: :type c: cute.Tensor :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. :type tile_shape_mnk: tuple[int, int, int] - :param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions. - :type cluster_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] :return: Grid shape for kernel launch. :rtype: tuple[int, int, int] @@ -1212,8 +1224,9 @@ class HopperWgmmaGemmKernel: c_shape = (tile_shape_mnk[0], tile_shape_mnk[1]) gc = cute.zipped_divide(c, tiler=c_shape) - clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk) - grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk)) + cluster_shape_mnl = (*cluster_shape_mn, 1) + clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnl) + grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnl)) return grid @staticmethod @@ -1363,7 +1376,7 @@ def run( a_major: str, b_major: str, c_major: str, - tile_shape_mnk: Tuple[int, int, int], + tile_shape_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], tolerance: float, warmup_iterations: int, @@ -1387,8 +1400,8 @@ def run( :type acc_dtype: Type[cutlass.Numeric] :param a_major/b_major/c_major: Memory layout of tensor A/B/C :type a_major/b_major/c_major: str - :param tile_shape_mnk: CTA tile shape (M, N, K) - :type tile_shape_mnk: Tuple[int, int, int] + :param tile_shape_mn: CTA tile shape (M, N) + :type tile_shape_mn: Tuple[int, int] :param cluster_shape_mn: Cluster shape (M, N) :type cluster_shape_mn: Tuple[int, int] :param tolerance: Tolerance value for reference validation comparison @@ -1411,7 +1424,7 @@ def run( f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" ) print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") - print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}") + print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}") print(f"Tolerance: {tolerance}") print(f"Warmup iterations: {warmup_iterations}") print(f"Iterations: {iterations}") @@ -1420,7 +1433,6 @@ def run( # Unpack parameters m, n, k, l = mnkl - cluster_shape_mnk = (*cluster_shape_mn, 1) # Skip unsupported types if not HopperWgmmaGemmKernel.is_valid_dtypes( @@ -1488,7 +1500,7 @@ def run( b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype) c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) - gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk) + gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mn, cluster_shape_mn) torch_stream = torch.cuda.Stream() stream = cuda.CUstream(torch_stream.cuda_stream) @@ -1572,7 +1584,7 @@ if __name__ == "__main__": args.a_major, args.b_major, args.c_major, - args.tile_shape_mnk, + args.tile_shape_mn, args.cluster_shape_mn, args.tolerance, args.warmup_iterations, diff --git a/examples/python/CuTeDSL/notebooks/tensorssa.ipynb b/examples/python/CuTeDSL/notebooks/tensorssa.ipynb index 3e812681..40a56d97 100644 --- a/examples/python/CuTeDSL/notebooks/tensorssa.ipynb +++ b/examples/python/CuTeDSL/notebooks/tensorssa.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -41,22 +41,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "a_vec: tensor_value o (3, 4)>\n", - "b_vec: tensor_value o (3, 4)>\n", - "tensor(raw_ptr(0x0000000006cff170: f32, generic, align<4>) o (3,4):(4,1), data=\n", - " [[ 2.000000, 2.000000, 2.000000, 2.000000, ],\n", - " [ 2.000000, 2.000000, 2.000000, 2.000000, ],\n", - " [ 2.000000, 2.000000, 2.000000, 2.000000, ]])\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n", @@ -91,22 +78,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor_value o (4, 2, 3)> -> tensor_value o (4, 3)>\n", - "tensor(raw_ptr(0x00000000071acaf0: f32, generic, align<4>) o (4,3):(3,1), data=\n", - " [[ 3.000000, 4.000000, 5.000000, ],\n", - " [ 9.000000, 10.000000, 11.000000, ],\n", - " [ 15.000000, 16.000000, 17.000000, ],\n", - " [ 21.000000, 22.000000, 23.000000, ]])\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):\n", @@ -155,19 +129,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor_value o (4, 2, 3)> -> ?\n", - "tensor(raw_ptr(0x00000000013cbbe0: f32, generic, align<4>) o (1):(1), data=\n", - " [ 10.000000, ])\n" - ] - } - ], + "outputs": [], "source": [ "def slice_2():\n", " src_shape = (4, 2, 3)\n", @@ -195,40 +159,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n", - " [ 3.000000, ],\n", - " [ 3.000000, ],\n", - " [ 3.000000, ])\n", - "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n", - " [-1.000000, ],\n", - " [-1.000000, ],\n", - " [-1.000000, ])\n", - "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n", - " [ 2.000000, ],\n", - " [ 2.000000, ],\n", - " [ 2.000000, ])\n", - "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n", - " [ 0.500000, ],\n", - " [ 0.500000, ],\n", - " [ 0.500000, ])\n", - "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n", - " [ 0.000000, ],\n", - " [ 0.000000, ],\n", - " [ 0.000000, ])\n", - "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n", - " [ 1.000000, ],\n", - " [ 1.000000, ],\n", - " [ 1.000000, ])\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n", @@ -236,28 +169,22 @@ " b_vec = b.load()\n", "\n", " add_res = a_vec + b_vec\n", - " res.store(add_res)\n", - " cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n", + " cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n", "\n", " sub_res = a_vec - b_vec\n", - " res.store(sub_res)\n", - " cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n", + " cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n", "\n", " mul_res = a_vec * b_vec\n", - " res.store(mul_res)\n", - " cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n", + " cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n", "\n", " div_res = a_vec / b_vec\n", - " res.store(div_res)\n", - " cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n", + " cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n", "\n", " floor_div_res = a_vec // b_vec\n", - " res.store(floor_div_res)\n", - " cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n", + " cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n", "\n", " mod_res = a_vec % b_vec\n", - " res.store(mod_res)\n", - " cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n", + " cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n", "\n", "\n", "a = np.empty((3,), dtype=np.float32)\n", @@ -270,68 +197,31 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n", - " [ 3.000000, ],\n", - " [ 3.000000, ],\n", - " [ 3.000000, ])\n", - "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n", - " [-1.000000, ],\n", - " [-1.000000, ],\n", - " [-1.000000, ])\n", - "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n", - " [ 2.000000, ],\n", - " [ 2.000000, ],\n", - " [ 2.000000, ])\n", - "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n", - " [ 0.500000, ],\n", - " [ 0.500000, ],\n", - " [ 0.500000, ])\n", - "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n", - " [ 0.000000, ],\n", - " [ 0.000000, ],\n", - " [ 0.000000, ])\n", - "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n", - " [ 1.000000, ],\n", - " [ 1.000000, ],\n", - " [ 1.000000, ])\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n", " a_vec = a.load()\n", "\n", " add_res = a_vec + c\n", - " res.store(add_res)\n", - " cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n", + " cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n", "\n", " sub_res = a_vec - c\n", - " res.store(sub_res)\n", - " cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n", + " cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n", "\n", " mul_res = a_vec * c\n", - " res.store(mul_res)\n", - " cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n", + " cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n", "\n", " div_res = a_vec / c\n", - " res.store(div_res)\n", - " cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n", + " cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n", "\n", " floor_div_res = a_vec // c\n", - " res.store(floor_div_res)\n", - " cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n", + " cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]\n", "\n", " mod_res = a_vec % c\n", - " res.store(mod_res)\n", - " cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n", + " cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n", "\n", "a = np.empty((3,), dtype=np.float32)\n", "a.fill(1.0)\n", @@ -342,17 +232,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[False True False]\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n", @@ -378,17 +260,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[3 0 7]\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n", @@ -420,44 +294,23 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n", - " [ 2.000000, ],\n", - " [ 2.000000, ],\n", - " [ 2.000000, ])\n", - "tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n", - " [-0.756802, ],\n", - " [-0.756802, ],\n", - " [-0.756802, ])\n", - "tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n", - " [ 16.000000, ],\n", - " [ 16.000000, ],\n", - " [ 16.000000, ])\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def unary_op_1(res: cute.Tensor, a: cute.Tensor):\n", " a_vec = a.load()\n", "\n", " sqrt_res = cute.math.sqrt(a_vec)\n", - " res.store(sqrt_res)\n", - " cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n", + " cute.print_tensor(sqrt_res) # prints [2.000000, 2.000000, 2.000000]\n", "\n", " sin_res = cute.math.sin(a_vec)\n", " res.store(sin_res)\n", - " cute.print_tensor(res) # prints [-0.756802, -0.756802, -0.756802]\n", + " cute.print_tensor(sin_res) # prints [-0.756802, -0.756802, -0.756802]\n", "\n", " exp2_res = cute.math.exp2(a_vec)\n", - " res.store(exp2_res)\n", - " cute.print_tensor(res) # prints [16.000000, 16.000000, 16.000000]\n", + " cute.print_tensor(exp2_res) # prints [16.000000, 16.000000, 16.000000]\n", "\n", "a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n", "res = np.empty((3,), dtype=np.float32)\n", @@ -470,29 +323,18 @@ "source": [ "#### Reduction Operation\n", "\n", - "The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, `ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and performs this reduction along the dimensions specified by the `reduction_profile.`. The result is typically a new `TensorSSA` with reduced dimensions or a scalar value if reduces across all axes." + "The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, \n", + "`ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and \n", + "performs this reduction along the dimensions specified by the `reduction_profile`. The result \n", + "is typically a new `TensorSSA` with reduced dimensions or a scalar value if it reduces across \n", + "all axes." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "21.000000\n", - "tensor(raw_ptr(0x00007ffd1ea2bca0: f32, rmem, align<32>) o (2):(1), data=\n", - " [ 6.000000, ],\n", - " [ 15.000000, ])\n", - "tensor(raw_ptr(0x00007ffd1ea2bcc0: f32, rmem, align<32>) o (3):(1), data=\n", - " [ 6.000000, ],\n", - " [ 8.000000, ],\n", - " [ 10.000000, ])\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def reduction_op(a: cute.Tensor):\n", @@ -507,36 +349,138 @@ " 0.0,\n", " reduction_profile=0\n", " )\n", - " cute.printf(red_res) # prints 21.000000\n", + " cute.printf(red_res) # prints 21.000000\n", "\n", " red_res = a_vec.reduce(\n", " cute.ReductionOp.ADD,\n", " 0.0,\n", " reduction_profile=(None, 1)\n", " )\n", - " # We can't print the TensorSSA directly at this point, so we store it to a new Tensor and print it.\n", - " res = cute.make_fragment(red_res.shape, cutlass.Float32)\n", - " res.store(red_res)\n", - " cute.print_tensor(res) # prints [6.000000, 15.000000]\n", + " cute.print_tensor(red_res) # prints [6.000000, 15.000000]\n", "\n", " red_res = a_vec.reduce(\n", " cute.ReductionOp.ADD,\n", " 1.0,\n", " reduction_profile=(1, None)\n", " )\n", - " res = cute.make_fragment(red_res.shape, cutlass.Float32)\n", - " res.store(red_res)\n", - " cute.print_tensor(res) # prints [6.000000, 8.000000, 10.000000]\n", + " cute.print_tensor(red_res) # prints [6.000000, 8.000000, 10.000000]\n", "\n", "\n", "a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n", "reduction_op(from_dlpack(a))" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Broadcast\n", + "\n", + "`TensorSSA` supports broadcasting operations following NumPy's broadcasting rules. Broadcasting \n", + "allows you to perform operations on arrays of different shapes when certain conditions are met. \n", + "The key rules are:\n", + "\n", + "1. Source shape is padded with 1's to match the rank of target shape\n", + "2. The size in each mode of source shape must either be 1 or equal to target shape\n", + "3. After broadcasting, all modes should match target shape\n", + "\n", + "Let's look at some examples of broadcasting in action:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import cutlass\n", + "import cutlass.cute as cute\n", + "\n", + "\n", + "@cute.jit\n", + "def broadcast_examples():\n", + " a = cute.make_fragment((1,3), dtype=cutlass.Float32)\n", + " a[0] = 0.0\n", + " a[1] = 1.0\n", + " a[2] = 2.0\n", + " a_val = a.load()\n", + " cute.print_tensor(a_val.broadcast_to((4, 3)))\n", + " # tensor(raw_ptr(0x00007ffe26625740: f32, rmem, align<32>) o (4,3):(1,4), data=\n", + " # [[ 0.000000, 1.000000, 2.000000, ],\n", + " # [ 0.000000, 1.000000, 2.000000, ],\n", + " # [ 0.000000, 1.000000, 2.000000, ],\n", + " # [ 0.000000, 1.000000, 2.000000, ]])\n", + "\n", + " c = cute.make_fragment((4,1), dtype=cutlass.Float32)\n", + " c[0] = 0.0\n", + " c[1] = 1.0\n", + " c[2] = 2.0\n", + " c[3] = 3.0\n", + " cute.print_tensor(a.load() + c.load())\n", + " # tensor(raw_ptr(0x00007ffe26625780: f32, rmem, align<32>) o (4,3):(1,4), data=\n", + " # [[ 0.000000, 1.000000, 2.000000, ],\n", + " # [ 1.000000, 2.000000, 3.000000, ],\n", + " # [ 2.000000, 3.000000, 4.000000, ],\n", + " # [ 3.000000, 4.000000, 5.000000, ]])\n", + "\n", + "\n", + "broadcast_examples()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "The examples above demonstrate two key broadcasting scenarios:\n", + "\n", + "1. **Row Vector Broadcasting**: In the first example, we create a row vector `a` with shape \n", + " (1, 3) containing values [0.0, 1.0, 2.0]. When we broadcast it to shape (4, 3), the values \n", + " are repeated across the first dimension, resulting in:\n", + " ```\n", + " [[0.0, 1.0, 2.0],\n", + " [0.0, 1.0, 2.0],\n", + " [0.0, 1.0, 2.0],\n", + " [0.0, 1.0, 2.0]]\n", + " ```\n", + " This demonstrates how a row vector can be broadcast to create multiple identical rows.\n", + "\n", + "2. **Column Vector and Row Vector Addition**: In the second example, we have:\n", + " - A row vector `a` with shape (1, 3) containing [0.0, 1.0, 2.0]\n", + " - A column vector `c` with shape (4, 1) containing [0.0, 1.0, 2.0, 3.0]\n", + " \n", + " When we add these together, both vectors are broadcast to shape (4, 3):\n", + " - The row vector is broadcast vertically (4 times)\n", + " - The column vector is broadcast horizontally (3 times)\n", + " \n", + " The result is:\n", + " ```\n", + " [[0.0 + 0.0, 1.0 + 0.0, 2.0 + 0.0],\n", + " [0.0 + 1.0, 1.0 + 1.0, 2.0 + 1.0],\n", + " [0.0 + 2.0, 1.0 + 2.0, 2.0 + 2.0],\n", + " [0.0 + 3.0, 1.0 + 3.0, 2.0 + 3.0]]\n", + " ```\n", + " =\n", + " ```\n", + " [[0.0, 1.0, 2.0],\n", + " [1.0, 2.0, 3.0],\n", + " [2.0, 3.0, 4.0],\n", + " [3.0, 4.0, 5.0]]\n", + " ```\n", + "\n", + "This demonstrates how `TensorSSA` can automatically handle broadcasting of both row and column \n", + "vectors in arithmetic operations, following the broadcasting rules where each dimension must \n", + "either be 1 or match the target size. The broadcasting is handled implicitly during operations, \n", + "making it easy to work with tensors of different shapes.\n" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv3_12", "language": "python", "name": "python3" }, @@ -550,7 +494,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.10" } }, "nbformat": 4, diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 58739ef9..718b3421 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -539,7 +539,8 @@ make_cotiled_copy(Copy_Atom const& copy_atom, auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); // Check validity - CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + // Append 1:0 to data_layout so that OOB coordinates get the stride-0 + CUTE_STATIC_ASSERT_V(coalesce(composition(make_layout(data_layout, Layout<_1,_0>{}), layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); // // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index b6a19c94..89d25000 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -705,7 +705,7 @@ public: auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info); auto&& scales = cute::get<1>(partitioned_transform_extra_info); using ScaleType = decltype(scales); - auto tSrS = make_tensor(static_cast(scales).data(), scales.layout()); + auto tSrS = make_tensor(scales.data(), scales.layout()); auto tSsS = cute::get<2>(partitioned_transform_extra_info); copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS); @@ -714,7 +714,7 @@ public: } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { auto&& zeros = cute::get<3>(partitioned_transform_extra_info); using ZeroType = decltype(zeros); - auto tZrZ = make_tensor(static_cast(zeros).data(), zeros.layout()); + auto tZrZ = make_tensor(zeros.data(), zeros.layout()); auto tZsZ = cute::get<4>(partitioned_transform_extra_info); copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ); @@ -1002,7 +1002,6 @@ public: auto src_arr = recast(src); auto dst_arr = recast(dst); - Tensor src_vm = cute::group_modes<1,-1>(cute::zipped_divide(src, pack)); Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); cute::transform(src_arr, dst_arr, Converter::convert); @@ -1019,7 +1018,6 @@ public: auto scale_arr = recast(filter_zeros(scales)); if constexpr (is_same_v){ - Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack)); for (int i = 0; i < size<1>(dst_vm); ++i){ @@ -1194,13 +1192,7 @@ public: Tensor tCsS = cta_mma.partition_A(sS); Tensor tSsS = smem_thr_copy_S.partition_S(tCsS); Tensor tSrS = make_tensor(tSsS(_,_,_,_,0).shape()); -#if 0 - if(cute::thread(128, 0)){ - print("sS: ");print(sS);print("\n"); - print("tSsS: ");print(tSsS);print("\n"); - print("tSrS: ");print(tSrS);print("\n"); - } -#endif + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS); } @@ -1209,16 +1201,6 @@ public: Tensor tCsZ = cta_mma.partition_A(sZ); Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ); Tensor tZrZ = make_tensor(tZsZ(_,_,_,_,0).shape()); -#if 0 - if(cute::thread(128, 0)){ - print("sS: ");print(sS);print("\n"); - print("tSsS: ");print(tSsS);print("\n"); - print("tSrS: ");print(tSrS);print("\n"); - print("sZ: ");print(sZ);print("\n"); - print("tZsZ: ");print(tZsZ);print("\n"); - print("tZrZ: ");print(tZrZ);print("\n"); - } -#endif return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ); } else { diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index 4c3fd96b..8634134b 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -582,7 +582,13 @@ sm100_sparse_get_tma_dispatch_policy() { * Selected op also maximizes the TMEM_LOAD shape in order to minimize TMEM_LOADs issued, * subject to the constraint of the provided per-warp tmem subpartition shape **/ -template +template< + class GmemStrideTypeD, + class ElementAccumulator, + class ElementD, + class TmemShape_MN, + bool IsBlockScaleSupported +> constexpr auto sm100_get_tmem_load_op() { using namespace cute; @@ -958,6 +964,172 @@ struct CallbacksBuilder< >; }; +// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one. +template< + class OpClass, + class CtaTileShape_MNK, + class EpilogueTileType, + class TmemWarpShape_MN, + class ElementC_, + class GmemStrideTypeC, + class ElementD, + class GmemStrideTypeD, + bool IsPerColScaleSupported +> +static constexpr auto +sm100_dense_compute_tile_shape_or_override() { + using namespace cute; + static_assert(!cute::is_same_v && !cute::is_same_v); + + constexpr bool DisableSource = cute::is_void_v; + using ElementC = cute::conditional_t; + + if constexpr (is_same_v && + is_same_v && + size<1>(CtaTileShape_MNK{}) == 256) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int DpFull = 32; + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + // Note: + // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. + // This is a general workable epi_tile_N which does not promise best perf. + return make_tile(Int{}, Int<128>{}); + } + else if constexpr (is_same_v) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int CtaN = size<1>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int WarpN = size<1>(TmemWarpShape_MN{}); + constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); + + constexpr int DpFull = 32; // tmem datapaths in 1 subpartition + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf + // Epilogues w/o residual load are less sensitive to smem allocation + // Target a fixed amount of compute per epilogue iteration + if (DisableSource) { + if (MaxBits == 4) { + // Make epilogue tile larger to reduce the epilogue iterations. + // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. + constexpr int ComputeElts = 8192; + return ComputeElts / M; + } + constexpr int ComputeElts = 4096; + return ComputeElts / M; + } + // Epilogues w/ residual load are more sensitive to smem allocation + // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + else { + if (MaxBits == 32) { + return (CtaM > 64 && CtaN <= 128) ? 16 : 32; + } + // Per-column scaling is high register pressure, reduce tile to prevent spills + else if (IsPerColScaleSupported) { + return 32; + } + else if (MaxBits == 16) { + return (CtaN <= 128) ? 32 : 64; + } + else { + return 64; + } + } + }(); + constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); + static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); + + // stride by tmem warp layout and return a by-mode tiler + auto tile_m = Layout>{}; + auto tile_n = Layout,Int< WarpN>>, + Stride,Int>>{}; + + return make_tile(tile_m, coalesce(tile_n)); + } + else { + static_assert(cute::is_tuple::value && not is_layout::value, + "EpilogueTile must be a cute::Tile or cute::Shape"); + + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + static_assert(N % 8 == 0, "Unsupported tile shape"); + + return epi_tile; + } +} + +template< + bool Is2SmMma, + class MmaTileShape_MNK +> +static constexpr auto +sm100_tmem_warps() { + if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { + return Shape<_2,_2>{}; + } + else { + return Shape<_4,_1>{}; + } +} + +template< + bool Is2SmMma, + class MmaTileShape_MNK +> +static constexpr auto +sm100_cta_tile_shape() { + if constexpr (Is2SmMma) { // 2x1 threadblock shape + auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{}; + auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode + return make_shape(cta_tile_m, mma_tile_n, mma_tile_k); + } + else { // 1x1 threadblock shape + return MmaTileShape_MNK{}; + } +} + +template< + class EpilogueScheduleType, + class ElementC_, + class ElementD, + int EpiTiles, + int FragmentSize +> +static constexpr auto +sm100_dense_dispatch_policy() { + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = sizeof_bits_v > 8; + // TMA store delay performs worse with residual loads + constexpr bool DelayTmaStore = is_void_v; + + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + if constexpr (is_base_of_v || + is_base_of_v) { + return Sm100PtrArrayNoSmemWarpSpecialized{}; + } + else if constexpr (is_base_of_v || is_base_of_v) { + return Sm100NoSmemWarpSpecialized{}; + } + else if constexpr (is_same_v || + is_same_v) { + constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs + return Sm100PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm100TmaWarpSpecialized{}; + } +} + // Helper for building TMA warp-specialized collective epilogues, specialized by // the fusion operation performed and the dispatch policy to use. template < @@ -1017,17 +1189,7 @@ private: } } using CtaTileShape_MNK = decltype(cta_tile_shape()); - - static constexpr auto - tmem_warps() { - if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { - return Shape<_2,_2>{}; - } - else { - return Shape<_4,_1>{}; - } - } - using TmemWarpShape_MN = decltype(tmem_warps()); + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); // Attempts to compute a reasonably performant epilogue tile or allows the user to provide one. static constexpr auto @@ -1041,84 +1203,10 @@ private: ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, Schedule, FusionOp>(); } - else if constexpr (is_same_v && - is_same_v && - size<1>(CtaTileShape_MNK{}) == 256) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int DpFull = 32; - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - // Note: - // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. - // This is a general workable epi_tile_N which does not promise best perf. - return make_tile(Int{}, Int<128>{}); - } - else if constexpr (is_same_v) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int CtaN = size<1>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int WarpN = size<1>(TmemWarpShape_MN{}); - constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); - - constexpr int DpFull = 32; // tmem datapaths in 1 subpartition - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf - // Epilogues w/o residual load are less sensitive to smem allocation - // Target a fixed amount of compute per epilogue iteration - if (DisableSource) { - if (MaxBits == 4) { - // Make epilogue tile larger to reduce the epilogue iterations. - // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. - constexpr int ComputeElts = 8192; - return ComputeElts / M; - } - constexpr int ComputeElts = 4096; - return ComputeElts / M; - } - // Epilogues w/ residual load are more sensitive to smem allocation - // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize - else { - if (MaxBits == 32) { - return (CtaM > 64 && CtaN <= 128) ? 16 : 32; - } - // Per-column scaling is high register pressure, reduce tile to prevent spills - else if (FusionOp::IsPerColScaleSupported) { - return 32; - } - else if (MaxBits == 16) { - return (CtaN <= 128) ? 32 : 64; - } - else { - return 64; - } - } - }(); - constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); - static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); - - // stride by tmem warp layout and return a by-mode tiler - auto tile_m = Layout>{}; - auto tile_n = Layout,Int< WarpN>>, - Stride,Int>>{}; - - return make_tile(tile_m, coalesce(tile_n)); - } else { - static_assert(cute::is_tuple::value && not is_layout::value, - "EpilogueTile must be a cute::Tile or cute::Shape"); - - EpilogueTileType epi_tile; - constexpr int M = size<0>(shape(epi_tile)); - constexpr int N = size<1>(shape(epi_tile)); - static_assert(N % 8 == 0, "Unsupported tile shape"); - - return epi_tile; + return sm100_dense_compute_tile_shape_or_override< + OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN, + ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp::IsPerColScaleSupported>(); } } using EpilogueTile_MN = decltype(epilogue_tile()); @@ -1129,30 +1217,18 @@ private: using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + FusionOp::IsBlockScaleSupported + >()); static constexpr auto dispatch_policy() { - // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation - constexpr bool ReuseSmem = sizeof_bits_v > 8; - // TMA store delay performs worse with residual loads - constexpr bool DelayTmaStore = is_void_v; - - constexpr int StagesD = cute::min(EpiTiles, 2); - constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) - : cute::min(EpiTiles, 4); - if constexpr (is_same_v || is_same_v) { return detail::sparse::sm100_sparse_get_tma_dispatch_policy(); } - else if constexpr (is_same_v || - is_same_v) { - constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs - return Sm100PtrArrayTmaWarpSpecialized{}; - } else { - return Sm100TmaWarpSpecialized{}; + return detail::sm100_dense_dispatch_policy(); } } @@ -1228,6 +1304,87 @@ public: >; }; +template< + class OpClass, + class MmaTileShape_MNK, + class EpilogueTileType, + class ElementAccumulator_, + class ElementC, + class ElementD, + class Schedule, + class GmemStrideTypeC, + class GmemStrideTypeD, + bool IsPerColScaleSupported, + bool IsBlockScaleSupported +> +struct Sm100EpilogueDescriptor { + using ElementAccumulator = ElementAccumulator_; + + static constexpr bool Is2SmMma = is_base_of_v || is_base_of_v; + using CtaTileShape_MNK = decltype(sm100_cta_tile_shape()); + using TileShape = CtaTileShape_MNK; + + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); + + using EpilogueTile = decltype( + sm100_dense_compute_tile_shape_or_override() + ); + + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{}))); + static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup; + + using DispatchPolicy = decltype(sm100_dense_dispatch_policy()); + + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; + + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); + using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + IsBlockScaleSupported + >()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct Sm100AuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + + using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, EpilogueTile>()); + + using CopyOpS2R = decltype(detail::sm100_get_smem_load_op< + Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct Sm100AuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + + using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, EpilogueTile>()); + + using CopyOpR2S = decltype(detail::sm100_get_smem_store_op< + Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>()); +}; + } // namespace detail /////////////////////////////////////////////////////////////////////////////// @@ -1304,17 +1461,7 @@ private: } } using CtaTileShape_MNK = decltype(cta_tile_shape()); - - static constexpr auto - tmem_warps() { - if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { - return Shape<_2,_2>{}; - } - else { - return Shape<_4,_1>{}; - } - } - using TmemWarpShape_MN = decltype(tmem_warps()); + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); static constexpr auto epilogue_tile() { @@ -1338,20 +1485,15 @@ private: using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{})); using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + FusionOp::IsBlockScaleSupported + >()); static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup; - static constexpr auto - dispatch_policy() { - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - return Sm100PtrArrayNoSmemWarpSpecialized{}; - } - else { - return Sm100NoSmemWarpSpecialized{}; - } - } - using DispatchPolicy = decltype(dispatch_policy()); + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{}))); + static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + + using DispatchPolicy = decltype(detail::sm100_dense_dispatch_policy()); static constexpr auto fusion_callbacks() { diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index af53a1c6..77ef3ed2 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -507,8 +507,7 @@ public: int thread_idx, TensorStorage& shared_tensors, TensorMapC const& load_tensormap, - int subtile_idx=-1, - bool wait_until_load_finishes = false) { + int subtile_idx=-1) { using namespace cute; // Indexing variables @@ -595,12 +594,6 @@ public: // Post-loop fusion callback entry point pld_callbacks.end(); - if (wait_until_load_finishes && did_load) { - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = - {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; - load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); - } - return load_pipe_producer_state; } diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl new file mode 100644 index 00000000..72cc3061 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +#include "cutlass/gemm/collective/collective_builder_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int stages +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCount stage_count) { + return stages; +} + +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int carveout_bytes +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCountAutoCarveout stage_count) { + // For MXF8F6F4 MMA, ElementA/B will be passed in as uint8_t + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage) + // 3. smem for SFB and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size b.c. sizeof(sf) is fixed) + constexpr auto mainloop_pipeline_bytes = + sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage) + + sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage); + + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{})); + constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{})); + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementPairA, + class GmemLayoutATag, + int AlignmentA, + class ElementPairB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassBlockScaledTensorOp, + ElementPairA, + GmemLayoutATag, + AlignmentA, + ElementPairB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape (_1, _1, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t > +> +{ + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + using ElementSF = ElementSFA; + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); + + static constexpr bool is_2sm = false; // detail::blockscaled::is_2sm(); + static constexpr auto Instr = detail::blockscaled::select_instr(); + + using TiledMma = typename cutlass::gemm::collective::detail::TrivialBlockscaledMma::type; + + static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8; + + static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B()), "Only MMA.MXF8F6F4 supports non-K major inputs"); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(), + "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); + + static constexpr uint32_t SFVectorSize = TiledMma::SFVecSize; + + using ElementAMma_SmemAllocType = cute::conditional_t; + using ElementBMma_SmemAllocType = cute::conditional_t; + + // using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + // ElementAMma, ElementBMma, ElementAccumulator, + // decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + // UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load of B + static constexpr int NumLoadThreadsCpAsync = 128; + + + using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{})); + + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(cutlass::sizeof_bits::value) * AlignmentB / 8>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{})); + using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{})); + + using SmemLayoutAtomSFA = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFA(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomSFB = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFB(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); + using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{})); + + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA()); + using InternalLayoutSFB = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB()); + using LayoutSFA = cute::conditional_t, InternalLayoutSFA, InternalLayoutSFA *>; + using LayoutSFB = cute::conditional_t, InternalLayoutSFB, InternalLayoutSFB *>; + using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{})); + using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{})); + + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, SFA, and SFB."); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + cute::tuple, + StridePairA, + cute::tuple, + StridePairB, + TiledMma, + GmemTiledCopyPairA, + SmemLayoutAtomsA, + void, + cute::identity, + GmemTiledCopyPairB, + SmemLayoutAtomsB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index 3556fad6..68600c67 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -120,6 +120,7 @@ struct CollectiveBuilder< BuilderScheduleTag, cute::enable_if_t< // Blockscaled Gemm + (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v) && diff --git a/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl new file mode 100644 index 00000000..5fd1201a --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl @@ -0,0 +1,171 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +#include "cutlass/gemm/collective/collective_builder_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape (_1, _1, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t > +> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + using ElementAMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementAMma>; + using ElementBMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementAMma, ElementBMma, ElementAccumulator, + decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load of B + static constexpr int NumLoadThreadsCpAsync = 128; + + + using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{})); + + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index 5edf637e..dfd4fece 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -184,6 +184,7 @@ struct CollectiveBuilder< not cute::is_complex_v && not cute::is_complex_v && // Dense Gemm / PtrArrayDenseGemm ( + (not cute::is_same_v) && (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v)) && diff --git a/include/cutlass/gemm/collective/builders/sm1xx_common.inl b/include/cutlass/gemm/collective/builders/sm1xx_common.inl index f63842c0..629f95dc 100644 --- a/include/cutlass/gemm/collective/builders/sm1xx_common.inl +++ b/include/cutlass/gemm/collective/builders/sm1xx_common.inl @@ -502,6 +502,7 @@ check_input_datatypes() { || (cute::is_same_v) || (cute::is_same_v) || (cute::is_same_v) + || (cute::is_same_v) // SM100 BS ptr_array || (cute::is_same_v) || (cute::is_same_v) @@ -578,6 +579,8 @@ check_input_datatypes() { ((SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 64 && cute::is_base_of_v) diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index c75af3ac..a1ea257e 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -1069,10 +1069,10 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< - (cute::is_same_v or - cute::is_same_v or - cute::is_same_v or - cute::is_same_v) and + (cute::is_same_v or + cute::is_same_v or + cute::is_same_v or + cute::is_same_v) and not detail::is_use_rmem_A() > > { @@ -1105,7 +1105,7 @@ struct CollectiveBuilder< cute::is_base_of_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert(IsFP8Input, "Warp Specialized gemm with FP8 BlockScaled Accumulator is only compatible with FP8 Blocked Scaled version right now."); + static_assert(IsFP8Input, "Warp Specialized gemm with FP8 Blockwise (Software) Scaling is only compatible with FP8 inputs version right now."); // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; @@ -1146,8 +1146,8 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale(StageCountType{}); using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8>; + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise, + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index ad2667dd..83a65059 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -49,6 +49,8 @@ #include "cutlass/gemm/collective/builders/sm100_simt_builder.inl" #include "cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl" diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index f698b79c..9e3ae800 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -65,6 +65,8 @@ #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp" #include "cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm120_mma_tma.hpp" diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 00000000..344de4d3 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,1043 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> { + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + using TiledMma_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + static_assert(!IsOverlappingAccum, "TMA+CPASYNC kernel currently only supports non-overlapping accum."); + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = remove_cvref_t(StridePairA{}))>; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ATmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ATmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + cluster_layout_sfb_vmnk); + + return { + tma_load_a, + tma_load_sfa, + tma_load_sfb, + args.ptr_B, + args.dB, + args.layout_SFA, + args.layout_SFB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + // static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + // Check for SFA SFB layout requirement + const auto layout_sfa_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + const auto layout_sfb_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + implementable = implementable && (layout_sfa_ref == args.layout_SFA); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); + } + + implementable = implementable && (layout_sfb_ref == args.layout_SFB); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfa_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfb_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA_)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMma_SF{}.get_slice(blockIdx.x % size(typename TiledMma_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB // for input scale factor tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // convert to subptr iterator if necessary + auto ptr_B = recast_ptr(params.ptr_B); + // Represent the full tensors + Tensor mB_nkl = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, + tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + // class KTileCount, + // class GTensorPartitionedA, + // class STensorA, + class TileCoordMNKL, + class KTileIterator, + class... TLoadParams // see load_init_tma + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + // KTileCount k_tiles = get<0>(load_inputs); + // GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + // STensorA tAsA = get<2>(load_inputs); + + auto [k_tiles, + tAgA_mkl, tAsA, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + // Unpack from load_inputs + // GTensorB tBgB_nkl = get<0>(load_inputs); + // CTensorB cgB_nk = get<1>(load_inputs); + // STensorB sB = get<2>(load_inputs); + // ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs); + // TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs); + // ThreadCopyB thr_copy_b = get<5>(load_inputs); + + // auto [M,N,K,L] = problem_shape_MNKL; + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue != 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class CtaTileCoord, + class... TMmaParams + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage_tma), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage_tma), thr_tCtSFB_s2t); + } + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage_tma), + tCrB(_,_,k_block,read_stage_cpasync), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + // ClusterShape cluster_shape_; + // uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 00000000..c31ec335 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,758 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + + observed_tma_load_a_ = ¶ms.tma_load_a; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + + auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + return { + tma_load_a, + args.ptr_B, + args.dB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA // for input tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class KTileCount, + class GTensorPartitionedA, + class STensorA, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + KTileCount k_tiles = get<0>(load_inputs); + GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + STensorA tAsA = get<2>(load_inputs); + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + // Unpack from load_inputs + // GTensorB tBgB_nkl = get<0>(load_inputs); + // CTensorB cgB_nk = get<1>(load_inputs); + // STensorB sB = get<2>(load_inputs); + // ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs); + // TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs); + // ThreadCopyB thr_copy_b = get<5>(load_inputs); + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue != 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp index 5a5bf458..5adc2b81 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -248,7 +248,7 @@ public: using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync< DispatchPolicy::Load2TransformPipelineStageCount, - ClusterShape, + ClusterShape, AtomThrShapeMNK>; using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState; @@ -316,7 +316,7 @@ public: using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( SmemLayoutAtomACompute{}, - append(CtaShapeA_MK{}, Int{}), + append(CtaShapeA_MK{}, Int{}), (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( @@ -385,7 +385,7 @@ public: struct TensorStorageUntransformed { alignas(512) cute::ArrayEngine> smem_A; - cute::ArrayEngine> smem_B; + alignas(1024) cute::ArrayEngine> smem_B; cute::ArrayEngine smem_scale; cute::ArrayEngine smem_zero; }; diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 7f956539..b6e662be 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -73,7 +73,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling, + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise, TileShape_, ElementA_, StridePairA_, @@ -92,7 +92,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling; + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = cute::tuple_element_t<0,StridePairA_>; @@ -382,8 +382,6 @@ struct CollectiveMma< auto [M,N,K,L] = problem_shape_MNKL; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); - // We expect full tiles in K - implementable = implementable && K % size<2>(TileShape{}) == 0; } } @@ -824,16 +822,13 @@ struct CollectiveMma< // Prologue GMMAs tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // fence_operand(); GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); - warpgroup_fence_operand(accumulation()); - { + if (k_tile_count > 0) { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_read, barrier_token); int read_stage = smem_pipe_read.index(); // Load per block scale values from shared memory to registers @@ -977,7 +972,7 @@ struct CollectiveMma< ++smem_pipe_release; } - if (k_tile_count) { + if (k_tile_count > 0) { pipeline.consumer_wait(smem_pipe_read, barrier_token); // @@ -1072,9 +1067,11 @@ struct CollectiveMma< /// Perform a Consumer Epilogue to release all buffers CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // The pipeline is not released in the first iteration - smem_pipe_release.advance(k_tile_count - 1); - pipeline.consumer_release(smem_pipe_release); + if (k_tile_count > 0) { + // The pipeline is not released in the first iteration + smem_pipe_release.advance(k_tile_count - 1); + pipeline.consumer_release(smem_pipe_release); + } } // diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 7def1f32..48ddf7a0 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -73,7 +73,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8, TileShape_, ElementA_, StridePairA_, @@ -91,7 +91,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = cute::tuple_element_t<0,StridePairA_>; @@ -391,12 +391,6 @@ struct CollectiveMma< implementable = false; CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n"); } - - // We expect full tiles in K - if (K % size<2>(TileShape{}) != 0) { - implementable = false; - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size K is incompatible with tile size.\n"); - } return implementable; } diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 314f99f5..6f42fc7b 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -127,10 +127,15 @@ struct KernelPtrArrayTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedPingpong { }; // FP8 related policies (including Blocked Scaled Accumulation) -struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { }; -struct KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelTmaWarpSpecializedPingpong { }; -struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { }; -struct KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedCooperativeFP8Blockwise: KernelTmaWarpSpecializedCooperative { }; +struct KernelTmaWarpSpecializedPingpongFP8Blockwise: KernelTmaWarpSpecializedPingpong { }; +struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise: KernelPtrArrayTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise: KernelPtrArrayTmaWarpSpecializedPingpong { }; + +using KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelTmaWarpSpecializedCooperativeFP8Blockwise; +using KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelTmaWarpSpecializedPingpongFP8Blockwise; +using KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; +using KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; // Policies to opt into mixed type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -319,17 +324,17 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8 // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule -// For FP8 kernels with Block Scaling +// For FP8 kernels with Blockwise (Software) Scaling template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum + class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8Blockwise > -struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8 +struct MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8 : MainloopSm90TmaGmmaWarpSpecialized { static_assert( - cute::is_same_v || - cute::is_same_v, + cute::is_same_v || + cute::is_same_v, "KernelSchedule must be one of the warp specialized FP8 block scale policies"); }; @@ -411,15 +416,15 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput { template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum + class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise > -struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling +struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise : MainloopSm90ArrayTmaGmmaWarpSpecialized { static_assert( cute::is_any_of_v< KernelSchedule, - KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum, - KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum + KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise, + KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise >, "KernelSchedule must be one of the warp specialized FP8 block scale policies"); }; @@ -440,6 +445,15 @@ struct KernelWarpSpecializedSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelMixedTmaCpAsyncWarpSpecializedSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + template< int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_ @@ -653,7 +667,7 @@ template< class KernelSchedule > struct HasAuxiliaryLoad< - MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling< + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise< Stages, ClusterShape, KernelSchedule @@ -666,7 +680,7 @@ template< class KernelSchedule > struct HasAuxiliaryLoad< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8< + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8< Stages, ClusterShape, KernelSchedule @@ -700,6 +714,7 @@ struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder struct KernelWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder Without TMA +struct KernelMixedTmaCpAsyncWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Ptr-Array Dense GEMM Dispatch Policies @@ -795,6 +810,8 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1 struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { }; struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { }; struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { }; +struct KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 {}; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -950,6 +967,34 @@ struct MainloopSm100UmmaCpAsyncWarpSpecialized { using Schedule = KernelWarpSpecializedSm100; }; +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; + +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, diff --git a/include/cutlass/gemm/group_array_problem_shape.hpp b/include/cutlass/gemm/group_array_problem_shape.hpp index fe5e4c53..400f7e6b 100644 --- a/include/cutlass/gemm/group_array_problem_shape.hpp +++ b/include/cutlass/gemm/group_array_problem_shape.hpp @@ -79,6 +79,16 @@ struct GroupProblemShape { } }; +template +struct MoEProblemShape { + using UnderlyingProblemShape = ProblemShape_; + using MaxProblemShape = MaxProblemShape_; + + UnderlyingProblemShape problem_shape; + MaxProblemShape max_problem_shape; +}; + + template class ArrayProblemShape { public: @@ -120,4 +130,14 @@ private: UnderlyingProblemShape problem_shape_{}; }; + +namespace detail { + +template +struct is_moe_problem_shape : cute::false_type {}; +template +struct is_moe_problem_shape> : cute::true_type {}; + +} + } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 7b086e27..b053963a 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -73,6 +73,7 @@ struct IsCutlass3ArrayKernel +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + using ProblemShape = ProblemShape_; + + static constexpr bool IsGroupedGemmKernel = cutlass::gemm::detail::is_moe_problem_shape::value; + static constexpr bool IsMoEScheduler = false; // stub for MoE scheduler, which accepts a MoEProblemShape instead of GroupProblemShape + + CUTLASS_HOST_DEVICE + static auto get_problem_shape_gemm(ProblemShape const& shape) { + if constexpr (IsGroupedGemmKernel) { + return shape.max_problem_shape; + } + else { + return shape; + } + } + CUTLASS_HOST_DEVICE + static auto get_problem_shape_scheduler(ProblemShape const& shape) { + if constexpr (IsMoEScheduler) { + return shape; + } + else if constexpr (IsGroupedGemmKernel) { + return shape.problem_shape; + } + else { + return shape; + } + } + + template + CUTLASS_HOST_DEVICE + static auto get_effective_shape(ProblemShape const& shape, WorkTileInfo const& work_tile_info) { + if constexpr (IsGroupedGemmKernel) { + return append<4>(shape.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } + else { + return append<4>(shape, Int<1>{}); + } + } + + using ProblemShapeGemm = decltype(get_problem_shape_gemm(ProblemShape{})); + using ProblemShapeScheduler = decltype(get_problem_shape_scheduler(ProblemShape{})); + + static_assert(rank(ProblemShapeGemm{}) == 3 or rank(ProblemShapeGemm{}) == 4, + "ProblemShapeGemm{} should be or "); + static constexpr bool IsGdcEnabled = false; + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + static_assert(!IsOverlappingAccum, "TMA+CPASYNC kernel currently only supports non-overlapping accum."); + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment kernel only supports 1x1x1 cluster shape."); + using TileSchedulerTag = cute::conditional_t; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount, ProblemShapeScheduler>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = 0; + static constexpr uint32_t NumMainloopTMALoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopCpAsyncLoadThreads = CollectiveMainloop::NumLoadThreadsCpAsync; // 4 warps + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_load_pipe_increment(CtaShape_MNK{}); + + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipelines and pipeline states + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + + // Pipeline and pipeline state types + using MainloopPipelineTMA = typename CollectiveMainloop::MainloopPipelineTMA; + using MainloopPipelineTMAState = typename CollectiveMainloop::MainloopPipelineTMAState; + using MainloopPipelineCpAsync = typename CollectiveMainloop::MainloopPipelineCpAsync; + using MainloopPipelineCpAsyncState = typename CollectiveMainloop::MainloopPipelineCpAsyncState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + // using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipeline = cute::conditional_t, + cutlass::PipelineAsync>; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + ProblemShapeGemm problem_shape_gemm{}; + ProblemShapeScheduler problem_shape_scheduler{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoadTMA = 2, + EpilogueLoad = 3, + Epilogue = 4, + MainloopLoadCpAsync = 8 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load_tma = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_load_cpasync = false; + }; + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + // auto problem_shape = args.problem_shape; + // auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + auto problem_shape_scheduler = get_problem_shape_scheduler(args.problem_shape); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_scheduler, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shape_scheduler, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace); + } + else { + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + scheduler = TileScheduler::to_underlying_arguments( + problem_shape, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ); + } + + return { + args.mode, + args.problem_shape, + problem_shape_gemm, + problem_shape_scheduler, + CollectiveMainloop::to_underlying_arguments(problem_shape_gemm, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape_gemm, args.epilogue, epilogue_workspace), + hw_info, + scheduler + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + + if constexpr (IsGroupedGemmKernel) { + implementable &= args.mode == GemmUniversalMode::kGrouped; + implementable &= rank(ProblemShapeGemm{}) == 4; + implementable &= rank(typename ProblemShape::UnderlyingProblemShape::UnderlyingProblemShape{}) == 3; + } + else { + implementable &= (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShapeGemm{}) == 4); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + implementable &= CollectiveMainloop::can_implement(problem_shape_gemm, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(problem_shape_gemm, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + static constexpr int MaxClusterSize = 16; + implementable &= size(ClusterShape{}) <= MaxClusterSize; + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + auto problem_shape_scheduler = get_problem_shape_scheduler(args.problem_shape); + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(problem_shape_gemm, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_scheduler, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + auto problem_shape_scheduler = get_problem_shape_scheduler(args.problem_shape); + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(problem_shape_gemm, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shape_gemm, args.epilogue); + status = cutlass::Status::kSuccess; + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, problem_shape_scheduler, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_scheduler, args.hw_info, NumFixupBarriers); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape_scheduler, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + else { + auto problem_shape_MNKL = append<4>(params.problem_shape_scheduler, 1); + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + return grid_shape; + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape_gemm, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::MainloopLoadCpAsync) ? WarpCategory::Epilogue + : WarpCategory::MainloopLoadCpAsync; + uint32_t lane_predicate = cute::elect_one_sync(); + auto tile_shape = TileShape{}; + auto cluster_shape = ClusterShape{}; + constexpr int cluster_size = size(ClusterShape{}); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + int mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + [[maybe_unused]] uint32_t mma_peer_cta_rank = cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + + // printf("is_epi_load_needed = %d", (int)is_epi_load_needed); + + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA) && is_mma_leader_cta, // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoadTMA), // main_load_tma + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopLoadCpAsync) // main_load_cpasync + }; + + // Mainloop Load pipeline (TMA) + typename MainloopPipelineTMA::Params mainloop_pipeline_tma_params; + if (WarpCategory::MainloopLoadTMA == warp_category) { + mainloop_pipeline_tma_params.role = MainloopPipelineTMA::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_tma_params.role = MainloopPipelineTMA::ThreadCategory::Consumer; + } + + mainloop_pipeline_tma_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load_tma; + mainloop_pipeline_tma_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_tma_params.initializing_warp = 0; + MainloopPipelineTMA mainloop_pipeline_tma(shared_storage.pipelines.mainloop.tma, + mainloop_pipeline_tma_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Mainloop Load pipeline (CpAsync) + typename MainloopPipelineCpAsync::Params mainloop_pipeline_cpasync_params; + if (WarpCategory::MainloopLoadCpAsync == warp_category) { + mainloop_pipeline_cpasync_params.role = MainloopPipelineCpAsync::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_cpasync_params.role = MainloopPipelineCpAsync::ThreadCategory::Consumer; + } + + mainloop_pipeline_cpasync_params.producer_arv_count = NumMainloopCpAsyncLoadThreads; + mainloop_pipeline_cpasync_params.consumer_arv_count = 1; // Only UMMA consumes the A and B buffers + mainloop_pipeline_cpasync_params.dst_blockid = cta_rank_in_cluster; + mainloop_pipeline_cpasync_params.initializing_warp = 0; + MainloopPipelineCpAsync mainloop_pipeline_cpasync(shared_storage.pipelines.mainloop.cpasync, mainloop_pipeline_cpasync_params, cluster_shape); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 3; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = IsSchedDynamicPersistent ? CLCPipeline::ThreadCategory::ProducerConsumer : CLCPipeline::ThreadCategory::Producer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_arv_count = 1; + + if constexpr (IsSchedDynamicPersistent) { + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + NumEpilogueThreads + NumMMAThreads); + clc_pipeline_params.transaction_bytes = CLCResponseSize; + } + else { + clc_pipeline_params.consumer_arv_count = NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + NumEpilogueThreads + NumMMAThreads; + } + + clc_pipeline_params.initializing_warp = 1; + // CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + // Now declare the pipeline outside the if constexpr + CLCPipeline clc_pipeline = [&]() { + if constexpr (IsSchedDynamicPersistent) { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + } + else { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params); + } + }(); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + + MainloopPipelineTMAState mainloop_pipe_tma_consumer_state; + MainloopPipelineTMAState mainloop_pipe_tma_producer_state = cutlass::make_producer_start_state(); + MainloopPipelineCpAsyncState mainloop_pipe_cpasync_consumer_state; + MainloopPipelineCpAsyncState mainloop_pipe_cpasync_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + // auto acc_shape = collective_mainloop.partition_accumulator_shape(); + // auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + // Int{})); + auto tmem_storage = collective_mainloop.template init_tmem_tensors(EpilogueTile{}); + + // + // END PROLOGUE + // + + // Synchronization call. Blocks until barriers are initialized in shared memory. + pipeline_init_wait(cluster_size); + + // __syncwarp(); + // if (threadIdx.x % 32 == 0) { + // printf("warp %d start\n", warp_idx); + // } + + if (is_participant.main_load_tma) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // bool do_load_order_arrive = is_epi_load_needed; + bool requires_clc_query = true; + + auto load_inputs = collective_mainloop.load_init_tma( + problem_shape_MNKL, shared_storage.tensors.mainloop); + auto k_tiles = cute::get<0>(load_inputs); + + do { + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, k_tiles); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + // auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_tma( + mainloop_pipeline_tma, + mainloop_pipe_tma_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count // - k_tile_prologue + ); + mainloop_pipe_tma_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail_tma(mainloop_pipeline_tma, mainloop_pipe_tma_producer_state); + + } + + else if (is_participant.main_load_cpasync) { + auto load_inputs = collective_mainloop.load_init_cpasync( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + scheduler, work_tile_info); + Tensor gA_mkl = get<0>(load_inputs); + + do { + // Get current work tile and fetch next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + + auto [mainloop_producer_state_next, unused_] = collective_mainloop.load_cpasync( + params.mainloop, + mainloop_pipeline_cpasync, + mainloop_pipe_cpasync_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count, + effective_shape + ); + mainloop_pipe_cpasync_producer_state = mainloop_producer_state_next; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + + collective_mainloop.load_tail_cpasync(mainloop_pipeline_cpasync, mainloop_pipe_cpasync_producer_state); + + } + + else if (is_participant.sched) { + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + else { + + cutlass::arch::wait_on_dependent_grids(); + + do { + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++clc_pipe_producer_state; + } + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.mma) { + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + // bulk_tmem.data() = tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + + + // Pass the acc with tuple type since the bgrad kernel change the mma_init API + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, + tmem_storage, + shared_storage.tensors.mainloop); + do { + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + // accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + int acc_stage = accumulator_pipe_producer_state.index(); + // Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + auto [mainloop_pipe_tma_consumer_state_next_, mainloop_pipe_cpasync_consumer_state_next_] = collective_mainloop.mma( + cute::make_tuple(mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline), + cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state), + // Pass the acc with tuple type since the bgrad kernel change the mma API + // cute::make_tuple(accumulators, accumulators), + collective_mainloop.slice_accumulator(tmem_storage, acc_stage), + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + mainloop_pipe_tma_consumer_state = mainloop_pipe_tma_consumer_state_next_; + mainloop_pipe_cpasync_consumer_state = mainloop_pipe_cpasync_consumer_state_next_; + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + bool do_tail_load = false; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + // bulk_tmem.data() = tmem_base_ptr; + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // Accumulator stage slice + int acc_stage = accumulator_pipe_consumer_state.index(); + // Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + auto accumulator = get<0>(collective_mainloop.slice_accumulator(tmem_storage, acc_stage)); + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulator, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulator, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index 07d00fb6..8cf885f8 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -240,6 +240,27 @@ public: void fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t, uint32_t = 1) const { } + template < + bool IsComplex, + class TiledMma, + class AccEngine, + class AccLayout, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class CopyOpT2R + > + CUTLASS_DEVICE + AccumulatorPipelineState + fixup( + TiledMma const& , + WorkTileInfo const&, + cute::Tensor&, + AccumulatorPipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + CopyOpT2R) const { + return acc_pipe_consumer_state; + } + template static size_t get_workspace_size(Arguments const& args, ProblemShape problem_shape, KernelHardwareInfo const& hw_info, uint32_t, uint32_t = 1, uint32_t = 1) { diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp index 5d6f13ec..06fd138d 100644 --- a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp @@ -991,7 +991,7 @@ public: mainloop_sf_pipeline, mainloop_sf_pipe_producer_state, load_inputs, - cta_coord_mnkl, + cta_coord_mnk, k_tile_iter_next, k_tile_count - k_tile_prologue, false, /* did_batch_change - prologue loads handle tensormap acquire */ enable_prefetch ? k_tile_count - k_tile_prologue : 0 diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 92bd5536..ec5cd4d0 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -831,8 +831,6 @@ public: collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } - bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; - epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -843,8 +841,7 @@ public: lane_idx, shared_storage.tensors.epilogue, epi_load_tensormap, - work_tile_info.reduction_subtile_idx(), - wait + work_tile_info.reduction_subtile_idx() ); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 6ac24d34..fd7ff603 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -869,8 +869,6 @@ public: collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } - bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; - epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -881,8 +879,7 @@ public: lane_idx, shared_storage.tensors.epilogue, epi_load_tensormap, - work_tile_info.reduction_subtile_idx(), - wait + work_tile_info.reduction_subtile_idx() ); } diff --git a/media/docs/cpp/profiler.md b/media/docs/cpp/profiler.md index 8331b75f..57949a94 100644 --- a/media/docs/cpp/profiler.md +++ b/media/docs/cpp/profiler.md @@ -79,7 +79,7 @@ Instruction shape levels control the selection of WGMMA shapes used in kernel ge - **Level 2**: Includes shapes that are powers of 2. - **Level 3**: Includes all other shapes. -The detailed defination of the three instantiation levels controlling cluster shape, MMA shape multiplier, and instruction shape can be found in [sm90_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py). +The detailed definition of the three instantiation levels controlling cluster shape, MMA shape multiplier, and instruction shape can be found in [sm90_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py). Schedule pruning levels decide the epilogue schedule and mainloop schedule to stamp out a kernel instance. As defined in `get_valid_schedules` in [sm90_utils.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_utils.py), @@ -122,6 +122,55 @@ For each mixed dtype kernel, the kernel generator will generate combinations of For {4-bits-dtype, 8-bits-dtype} x 16-bits-dtype, the kernel generator will further generate kernels using shuffled layouts for the narrow data type matrix, which may have a better performance compared to its non-shuffle counter parts. +## Instantiating more kernels with Blackwell +Blackwell (SM100) and Blackwell Ultra similarly support +`CUTLASS_LIBRARY_INSTANTIATION_LEVEL`, in order to instantiate all possible combinations. +Due to this, `CUTLASS_LIBRARY_KERNELS` must be non-empty, since generating and filtering these +kernels alone can take hours. +You must also exercise caution, because not all of these configs are tested, and some may fail to +compile or fail to launch at runtime. + +```bash +$ cmake .. \ + -DCUTLASS_NVCC_ARCHS="100f" \ + -DCUTLASS_LIBRARY_KERNELS="cutlass3x_sm100_tensorop_gemm_f16_f16_f32_void_f32_*" \ + -DCUTLASS_LIBRARY_INSTANTIATION_LEVEL="max" \ + -DCUTLASS_UNITY_BUILD_ENABLED=ON +``` + +The CUTLASS profiler uses the same four-digit integer level (global instantiation level) mechanism to manage the generation of kernel configurations for Blackwell as well: + +0. **Instruction Shape** +1. **MMA Shape Multiplier** +2. **Cluster Shape** +3. **Data Type and Schedule Pruning** + +Note for Blackwell kernels an MMA shape multiplier is no longer necessary since Blackwell kernels do not have a different +ping pong or cooperative schedule. The profiler ignores this digit when instantiating. + +Cluster shape levels define the number of CTAs (Cooperative Thread Arrays) included in the kernel generation: + +- **Level 0**: Only dynamic cluster shapes. +- **Level 1**: For 1SM kernels `(1, 1, 1)` and `(2, 1, 1)` for 2SM kernels. +- **Level 2**: For 1SM kernels we also have `(1, 2, 1)` and for 2SM we have `(2, 2, 1)` and `(4, 1, 1)`. +- **Level 3**: For 1SM kernels we have `(1, 4, 1)` and for 2SM we have `(2, 4, 1)` and `(4, 2, 1)`. +- **Level 4**: For 1SM kernels we have `(4, 4, 1)` and for 2SM we have `(4, 4, 1)`. +- **Level 5**: For 1SM kernels we have `(2, 1, 1)`. +- **Level 6**: For 1SM kernels we have `(2, 2, 1)` and `(4, 1, 1)` and for 2SM kernels we have `(8, 1, 1)`. +- **Level 7**: For 1SM kernels we have `(2, 4, 1)` and `(4, 2, 1)` +- **Level 8**: For 1SM kernels we have `(1, 8, 1)` and `(8, 1, 1)` + +Instruction shape levels control the selection of MMA shapes used in kernel generation: + +- **Level 0**: Generates the "default" shape only. +- **Level 1**: Includes additional shapes for FP8, FP6, and FP4 as well as MX and NVFP4. +- **Level 2**: Includes small tile shapes. +- **Level 3**: Includes some non-power of 2 shapes. +- **Level 4**: Includes further small tile shapes and non-power of 2 shapes. +- **Level 5**: Includes all shapes. + +The detailed definition of the three instantiation levels controlling cluster shape and instruction shape can be found in [sm100_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm100_shapes.py). + ## CUTLASS Profiler usage The CUTLASS Profiler usage statement may be obtained by executing `cutlass_profiler --help` and appears as follows. @@ -577,6 +626,10 @@ cutlass3x_sm90_tensorop_gemm_f16_f16_f16_void_f16_128x128x64_1x1x1_0_nnn_align8_ * `f16_f16_f16_void_f16`: In this case, C type is set to `void`, indicating that residual matrix support is disabled. +## Further Documentation + +For documentation on profiling blockwise and groupwise (software scaled) GEMMs see the [example 81 README](https://github.com/NVIDIA/cutlass/blob/main/examples/81_blackwell_gemm_blockwise/README.md). + # Convolution The CUTLASS Profiler is capable of executing 2-D and 3-D convolution problems for forwards and backwards diff --git a/media/docs/pythonDSL/cute_dsl_api.rst b/media/docs/pythonDSL/cute_dsl_api.rst index c4726eb3..2461af39 100644 --- a/media/docs/pythonDSL/cute_dsl_api.rst +++ b/media/docs/pythonDSL/cute_dsl_api.rst @@ -6,6 +6,7 @@ CuTe DSL API .. toctree:: :maxdepth: 1 + changelog cute cute_arch cute_nvgpu diff --git a/media/docs/pythonDSL/cute_dsl_api/changelog.rst b/media/docs/pythonDSL/cute_dsl_api/changelog.rst new file mode 100644 index 00000000..fc4d1fd4 --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_api/changelog.rst @@ -0,0 +1,54 @@ +====================================== +Changelog for CuTe DSL API changes +====================================== + +`4.2.0 `_ (2025-09-15) +============================================================================== + +* Added back ``cute.make_tiled_copy`` per the request from community +* Added support for explicit and implicit broadcast in ``TensorSSA`` + - ``cutlass.cute.TensorSSA``: support ``broadcast_to`` and implicit broadcasting for binary operations. +* Supported printing ``TensorSSA`` value in ``cutlass.cute.print_tensor`` +* Updated ``cute.gemm`` to support all dispatch patterns and improved checks for illegal inputs +* Introduced automatic kernel smem usage calculation for launch config. +* Introduced per op fast-math control for math ops(e.g. ``exp``, ``exp2``, ``log2``, ``log``) +* Introduced ``CopyReduceBulkTensorTileS2GOp`` in `tcgen05/copy.py `_ to support TMA Reduce. + + +`4.1.0 `_ (2025-07-16) +============================================================================== + +* for loop + + - Python built-in ``range`` now always generates codes and executes at runtime + - ``cutlass.range`` is advanced ``range`` with kernel code level unrolling and pipelining control + - Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range`` + - **Experimental** Added ``pipelining`` control for compiler generated software pipeline code + +* while/if + + - ``while``/``if`` now by default generates codes and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate + - Deprecated ``cutlass.dynamic_expr``, please remove it + +* Rename mbarrier functions to reduce ambiguity +* Modify SyncObject API (``MbarrierArray``, ``NamedBarrier``, ``TmaStoreFence``) to match ``std::barrier`` +* Change pipeline ``create`` function to take only keyword arguments, and make ``barrier_storage`` optional. +* Introduce ``cutlass.cute.arch.get_dyn_smem_size`` api to get runtime dynamic shared memory size. +* Various API Support for SM100 BlockScaled Gemm + + - Introduce BlockScaled MmaOps in `tcgen05/mma.py `_, and provide a ``make_blockscaled_trivial_tiled_mma`` function in `blackwell_helpers.py `_ to help construct a BlockScaled TiledMma. + - Introduce S2T CopyOps in `tcgen05/copy.py `_. + - Introduce BlockScaled layout utilities in `blockscaled_layout.py `_ for creating the required scale factor layouts in global memory, shared memory and tensor memory. + +* ``cutlass.cute.compile`` now supports compilation options. Refer to `JIT compilation options `_ for more details. +* ``cutlass.cute.testing.assert_`` now works for device JIT function. Specify ``--enable-device-assertions`` as compilation option to enable. +* ``cutlass.cute.make_tiled_copy`` is now deprecated. Please use ``cutlass.cute.make_tiled_copy_tv`` instead. +* Shared memory capacity query + + - Introduce ``cutlass.utils.get_smem_capacity_in_bytes`` for querying the shared memory capacity. + - ``_utils.SMEM_CAPACITY[""]`` is now deprecated. + +`4.0.0 `_ (2025-06-03) +============================================================================== + +* Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer`` diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst index 21f0912e..f5c2ea35 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst @@ -72,6 +72,55 @@ All loop indices must be |Constexpr|. for i in cutlass.range(bound, unroll=2): cute.printf("%d\\n", i) +Software Pipelining +~~~~~~~~~~~~~~~~~~~ + +Software pipelining is a technique used to optimize loops. Typically, this involves writing a prefetch loop and a main loop. + +.. code-block:: python + + @cute.jit + def example(): + ... + # build a circular buffer + buffer = ... + + # prefetch loop + for i in range(prefetch_stages): + cute.copy(atom, gmem[i], buffer[i], ...) + + # main loop + for i in range(bound): + if i + prefetch_stages < bound: + cute.copy(atom, gmem[i + prefetch_stages], buffer[(i + prefetch_stages) % total_stages], ...) + + use(buffer[i % total_stages]) + + ... + +This can be tedious to write and tune. |DSL| provides a loop attribute to ask the compiler to do this. + +.. code-block:: python + + @cute.jit + def example(): + ... + # build a circular buffer + buffer = ... + + for i in cutlass.range(bound, prefetch_stages=prefetch_stages): + # Compiler automatically handles the pipelining: + # - Generates prefetch loop for initial stages + # - In main loop, prefetches future data while using current data + cute.copy(atom, gmem[i], buffer[i % total_stages], ...) + use(buffer[i % total_stages]) # Uses data from previous iterations + + ... + +Compiler will automatically generate the prefetch loop with `prefetch_stages` iterations and a corresponding main loop. + +This feature is experimental and only supported on sm90 and above. + If-Else Statements ------------------ diff --git a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst index 269c6602..edd9eb94 100644 --- a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst +++ b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst @@ -7,7 +7,8 @@ Integration with Frameworks In order to facilitate the integration of CUTLASS Python with popular frameworks, we leverage the `DLPack protocol `_ and transform tensors originating from these frameworks to CuTe tensors. The present page documents the conventions, the API available to the -user, and provide example code snippets for common usage patterns. +user, and provide example code snippets for common usage patterns. We also provide a section on how to +bypass the DLPack protocol and directly call the JIT function. Implicit Conversion ------------------- @@ -396,3 +397,84 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to mode=0, divisibility=1, stride_order=(2, 1, 3, 0, 4) ) # The stride_order is not consistent with the layout + + +Bypass the DLPack Protocol +-------------------------- + +In certain scenarios, users may wish to bypass the DLPack protocol and invoke the JIT function directly. +This can be accomplished by creating a lightweight JIT wrapper around the existing JIT function, +utilizing ``cute.ptr`` and ``cute.make_tensor`` to pass pointers and construct tensors directly. + +Typical use cases for bypassing DLPack include: +1. Users want to call the JIT function directly to avoid the overhead introduced by the DLPack protocol. +2. DLPack canonicalizes the stride of shape-1 dimensions to 1, which may result in incorrect alignment +propagation and affect memory access or performance. +3. DLPack may lack support for some narrow data types. + +The following example illustrates how to bypass the DLPack protocol when invoking a JIT function. +Assume we have a pre-defined ``TensorOpGemm`` kernel whose JIT interface expects three +arguments of type ``cute.Tensor``. To enable direct invocation without DLPack, we first define a JIT wrapper +function that accepts ``cute.Pointer`` types as parameters. Within this wrapper, we use ``cute.make_tensor`` +to construct tensors from the provided pointers, and then call the ``TensorOpGemm`` kernel as usual. + +.. code-block:: python + + @cute.jit + def tensor_op_gemm_wrapper( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + c_ptr: cute.Pointer, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + l: cutlass.Int32, + ): + + # Assume alignment of shape to call tensorop_gemm example + m = cute.assume(m, divby=8) + n = cute.assume(n, divby=8) + + # Torch is row major + a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2)) + b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2)) + c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2)) + mA = cute.make_tensor(a_ptr, layout=a_layout) + mB = cute.make_tensor(b_ptr, layout=b_layout) + mC = cute.make_tensor(c_ptr, layout=c_layout) + + # TensorOpGemm is a pre-defined kernel from our example + tensor_op_gemm = TensorOpGemm( + a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1) + ) + + tensor_op_gemm(mA, mB, mC) + +To pass a PyTorch tensor to this new JIT wrapper, we retrieve the raw pointer from the PyTorch tensor +and create a ``cute.Pointer`` instance using ``cute.make_ptr``. +This approach allows us to bypass the DLPack protocol entirely, avoiding its overhead and potential +issues with shape-1 dimension handling. + +.. code-block:: python + + a = torch.randn( + m, k, l, dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + b = torch.randn( + n, k, l, dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + c = torch.randn( + n, m, l, dtype=torch.float16, device="cuda" + ).permute(1, 2, 0) + + # from cutlass.cute.runtime import make_ptr + a_ptr = make_ptr( + cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + b_ptr = make_ptr( + cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + c_ptr = make_ptr( + cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, m, n, k, l) diff --git a/pyproject.toml b/pyproject.toml index c046dc94..d14493b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "4.0.0.0" +version = "4.2.0.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" diff --git a/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py b/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py index 8109d3c2..60cc8db3 100644 --- a/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py +++ b/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py @@ -207,7 +207,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None): if dst_width == src_width: return a - elif src_signed and not dst_signed: + elif src_signed != False and not dst_signed: # Signed -> Unsigned if dst_width > src_width: return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) @@ -216,7 +216,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None): elif src_signed == dst_signed: # Same signedness if dst_width > src_width: - if src_signed and src_width > 1: + if src_signed != False and src_width > 1: return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip) else: return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) @@ -479,7 +479,7 @@ class ArithValue(ir.Value): if self.is_float: q = arith.divf(self, other, loc=loc, ip=ip) return math.floor(q, loc=loc, ip=ip) - elif self.signed: + elif self.signed != False: return arith.floordivsi(self, other, loc=loc, ip=ip) else: return arith.divui(self, other, loc=loc, ip=ip) @@ -489,7 +489,7 @@ class ArithValue(ir.Value): def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue": if self.is_float: return arith.remf(self, other, loc=loc, ip=ip) - elif self.signed: + elif self.signed != False: return arith.remsi(self, other, loc=loc, ip=ip) else: return arith.remui(self, other, loc=loc, ip=ip) @@ -524,7 +524,7 @@ class ArithValue(ir.Value): def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip) - elif self.signed: + elif self.signed != False: return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip) @@ -534,7 +534,7 @@ class ArithValue(ir.Value): def __le__(self, other, *, loc=None, ip=None) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip) - elif self.signed: + elif self.signed != False: return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip) @@ -561,7 +561,7 @@ class ArithValue(ir.Value): def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip) - elif self.signed: + elif self.signed != False: return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip) @@ -571,7 +571,7 @@ class ArithValue(ir.Value): def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip) - elif self.signed: + elif self.signed != False: return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip) @@ -599,7 +599,7 @@ class ArithValue(ir.Value): @_dispatch_to_rhs_r_op @_binary_op def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.signed: + if self.signed != False: return arith.shrsi(self, other, loc=loc, ip=ip) else: return arith.shrui(self, other, loc=loc, ip=ip) @@ -633,7 +633,7 @@ class ArithValue(ir.Value): return super().__hash__() def __str__(self): - return super().__str__().replace(ir.Value.__name__, ArithValue.__name__) + return "?" def __repr__(self): return self.__str__() @@ -657,7 +657,7 @@ def _min(lhs, rhs, *, loc=None, ip=None): rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) if arith._is_integer_like_type(lhs.type): - if lhs.signed: + if lhs.signed != False: return arith.minsi(lhs, rhs, loc=loc, ip=ip) else: return arith.minui(lhs, rhs, loc=loc, ip=ip) @@ -683,7 +683,7 @@ def _max(lhs, rhs, *, loc=None, ip=None): rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) if arith._is_integer_like_type(lhs.type): - if lhs.signed: + if lhs.signed != False: return arith.maxsi(lhs, rhs, loc=loc, ip=ip) else: return arith.maxui(lhs, rhs, loc=loc, ip=ip) diff --git a/python/CuTeDSL/base_dsl/ast_helpers.py b/python/CuTeDSL/base_dsl/ast_helpers.py index b857e40e..7b0832b8 100644 --- a/python/CuTeDSL/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/base_dsl/ast_helpers.py @@ -17,12 +17,16 @@ The preprocessor read through python's ast and changes the input code. from typing import Callable, Iterator, Optional, overload from typing_extensions import deprecated import warnings +import inspect +from types import BuiltinFunctionType +from functools import lru_cache from .utils.logger import log from .common import * from ._mlir_helpers.arith import ArithValue + class Executor: """ The Executor class handles dynamic and compile-time (constexpr) execution @@ -45,9 +49,11 @@ class Executor: self._compare_executor = None self._any_executor = None self._all_executor = None + self._builtin_redirector = None def set_functions( self, + *, is_dynamic_expression: Callable, loop_execute_range_dynamic: Callable, if_dynamic: Callable, @@ -55,6 +61,7 @@ class Executor: compare_executor: Callable, any_executor: Callable = None, all_executor: Callable = None, + builtin_redirector: Callable = None, ): self._is_dynamic_expression = is_dynamic_expression self._loop_execute_range_dynamic = loop_execute_range_dynamic @@ -63,6 +70,7 @@ class Executor: self._compare_executor = compare_executor self._any_executor = any_executor self._all_executor = all_executor + self._builtin_redirector = builtin_redirector @staticmethod def convert_to_list(x): @@ -90,42 +98,18 @@ class Executor: return res[0] return res - @staticmethod - def for_constexpr( - func: Callable, - start: int, - stop: int, - step: int, - used_args: list, - iter_args: list, - ): - log().debug("start [%s] stop [%s] step [%s]", start, stop, step) - loop_results = iter_args - log().debug("iter_args [%s]", iter_args) - for i in range(start, stop, step): - log().debug("i [%s] iter_args [%s]", i, iter_args) - loop_results = func(i, *used_args, *loop_results) - log().debug("loop_results [%s]", loop_results) - if loop_results is None: - loop_results = [] - if not isinstance(loop_results, list): - loop_results = [loop_results] - - log().debug("done loop_results [%s]", loop_results) - return Executor.converge_ret_val(loop_results) - def for_execute( self, func, start, stop, step, - used_args=[], - iter_args=[], - iter_arg_names=[], + write_args=[], + full_write_args_count=0, + write_args_names=[], unroll=-1, unroll_full=False, - pipelining=None, + prefetch_stages=None, ): assert ( self._loop_execute_range_dynamic @@ -137,12 +121,12 @@ class Executor: start, stop, step, - used_args, - iter_args, - iter_arg_names, + write_args, + full_write_args_count, + write_args_names, unroll, unroll_full, - pipelining, + prefetch_stages, ) def if_execute( @@ -150,15 +134,20 @@ class Executor: pred, then_block: Callable, else_block: Optional[Callable] = None, - used_args=[], - yield_args=[], - yield_arg_names=[], + write_args=[], + full_write_args_count=0, + write_args_names=[], ): assert self._if_dynamic, "Functions must be set before execution." # MLIR generation return self._if_dynamic( - pred, then_block, else_block, used_args, yield_args, yield_arg_names + pred, + then_block, + else_block, + write_args, + full_write_args_count, + write_args_names, ) def while_execute( @@ -166,9 +155,9 @@ class Executor: pred, while_before_block: Callable, while_after_block: Callable, - used_args=[], - yield_args=[], - yield_arg_names=[], + write_args=[], + full_write_args_count=0, + write_args_names=[], ): assert self._while_dynamic, "Functions must be set before execution." @@ -176,9 +165,9 @@ class Executor: return self._while_dynamic( while_before_block, while_after_block, - used_args, - yield_args, - yield_arg_names, + write_args, + full_write_args_count, + write_args_names, ) @@ -194,23 +183,24 @@ def loop_selector( stop, step, *, - used_args=[], - iter_args=[], - iter_arg_names=[], + write_args=[], + full_write_args_count=0, + write_args_names=[], unroll=-1, unroll_full=False, - pipelining=None, + prefetch_stages=None, ): log().debug( - "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]", + "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]", start, stop, step, - used_args, - iter_args, + write_args, + full_write_args_count, + write_args_names, unroll, unroll_full, - pipelining, + prefetch_stages, ) from .typing import Integer, Numeric @@ -230,19 +220,19 @@ def loop_selector( start, stop, step, - used_args, - iter_args, - iter_arg_names, + write_args, + full_write_args_count, + write_args_names, unroll, unroll_full, - pipelining, + prefetch_stages, ) return ir_loop -def if_selector(pred, used_args=[], yield_args=[]): - log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args) +def if_selector(pred, write_args=[]): + log().debug("pred [%s] write_args [%s]", pred, write_args) # Handle Numeric types here? from .typing import Numeric @@ -251,14 +241,14 @@ def if_selector(pred, used_args=[], yield_args=[]): pred = pred.value def ir_loop(func): - return func(pred, *used_args, *yield_args) + return func(pred, *write_args) return ir_loop -def while_selector(pred, used_args=[], yield_args=[]): +def while_selector(pred, write_args=[]): def ir_while_loop(func): - return func(pred, *used_args, *yield_args) + return func(pred, *write_args) return ir_while_loop @@ -267,17 +257,17 @@ def while_executor( pred, while_before_block: Callable, while_after_block: Callable, - used_args=[], - yield_args=[], - yield_arg_names=[], + write_args=[], + full_write_args_count=0, + write_args_names=[], ): return executor.while_execute( pred, while_before_block, while_after_block, - used_args, - yield_args, - yield_arg_names, + write_args, + full_write_args_count, + write_args_names, ) @@ -285,12 +275,17 @@ def if_executor( pred, then_block: Callable, else_block: Optional[Callable] = None, - used_args=[], - yield_args=[], - yield_arg_names=[], + write_args=[], + full_write_args_count=0, + write_args_names=[], ): return executor.if_execute( - pred, then_block, else_block, used_args, yield_args, yield_arg_names + pred, + then_block, + else_block, + write_args, + full_write_args_count, + write_args_names, ) @@ -313,14 +308,17 @@ class range: - unroll: Number of iterations to unroll (0 or 1 = no unrolling) - unroll_full: Whether to fully unroll the loop - - pipelining: Compiler generated pipeline configuration + - prefetch_stages: Number of prefetch stages to generate """ + @overload - def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None): + def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None): pass @overload - def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None): + def __new__( + cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None + ): pass def __new__(cls, *args, **kwargs): @@ -340,6 +338,7 @@ def range_dynamic(*args, **kwargs): def range_constexpr(*args): raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.") + # ============================================================================= # If expressions # ============================================================================= @@ -405,7 +404,7 @@ def assert_executor(test, msg=None): else: raise DSLRuntimeError( "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.", - suggestion = "Please replace with runtime assert." + suggestion="Please replace with runtime assert.", ) @@ -413,10 +412,11 @@ def bool_cast(value): if executor._is_dynamic_expression(value): raise DSLRuntimeError( "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.", - suggestion = "Please explicitly convert to boolean with expressions like comparision." + suggestion="Please explicitly convert to boolean with expressions like comparision.", ) return bool(value) + def compare_executor(left, comparators, ops): """ Executes comparison operations with a left operand and a list of comparators. @@ -470,6 +470,19 @@ def all_executor(iterable): # ============================================================================= # Control flow checks # ============================================================================= +class DSLOptimizationWarning(Warning): + """ + This warning is used to warn the user about the optimization related issues in DSL. + """ + + def __init__(self, message): + self.message = message + super().__init__() + + def __str__(self): + return self.message + + def range_value_check(*args): """ Ensure all `range_constexpr` bounds are compile-time constants (Python ints). @@ -495,7 +508,7 @@ def range_value_check(*args): if range_length >= 64: warnings.warn( f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.", - category=UserWarning, + category=DSLOptimizationWarning, stacklevel=2, ) @@ -519,7 +532,50 @@ def range_perf_warning(filename, lineno, *args): "This loop is no longer unrolled and may cause performance regression. " "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants." ), - category=UserWarning, + category=DSLOptimizationWarning, filename=filename, lineno=lineno, ) + + +@lru_cache(maxsize=1) +def _get_self_module(): + """ + This function is used to get the owning module of this function. + """ + return inspect.getmodule(_get_self_module) + + +def cf_symbol_check(symbol): + """ + Check if the symbol is control flow symbol from current module. + """ + + failed = False + name = symbol.__name__ + self_module = _get_self_module() + if inspect.ismodule(symbol): + name = "range" + if not self_module.__name__.startswith(symbol.__name__): + failed = True + else: + owning_module = inspect.getmodule(symbol) + if owning_module != self_module: + failed = True + + if failed: + raise DSLRuntimeError( + f"Incorrect {symbol.__name__} is used.", + suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.", + ) + + +def redirect_builtin_function(fcn): + """ + This function is used to redirect built-in function call + to the function defined in DSL package. + """ + # Only redirect if it's a built-in + if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector: + return executor._builtin_redirector(fcn) + return fcn diff --git a/python/CuTeDSL/base_dsl/ast_preprocessor.py b/python/CuTeDSL/base_dsl/ast_preprocessor.py index bffbc7f2..b9991a75 100644 --- a/python/CuTeDSL/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/base_dsl/ast_preprocessor.py @@ -41,6 +41,7 @@ from dataclasses import dataclass from typing import List, Set, Dict, Any, Callable, Optional from types import ModuleType from collections import OrderedDict +from copy import deepcopy from .common import * from .utils.logger import log @@ -88,6 +89,16 @@ class OrderedSet: return result +@dataclass +class ImportInfo: + """ + Information about an import expression. + """ + module_path: str + attr_name: Optional[str] + alias_name: str + + @dataclass class ScopeManager: """ @@ -139,7 +150,7 @@ class DSLPreprocessor(ast.NodeTransformer): ANY_EXECUTOR = "any_executor" ALL_EXECUTOR = "all_executor" - def __init__(self): + def __init__(self, client_module_name): super().__init__() self.counter = 0 # Unique function names for multiple loops self.scope_manager = ScopeManager.create() @@ -151,10 +162,39 @@ class DSLPreprocessor(ast.NodeTransformer): self.function_depth = 0 self.local_closures = set() self.function_globals = None + self.client_module_name = client_module_name + self.import_top_module = False + + def _create_module_attribute( + self, + func_name, + *, + top_module_name="_dsl_", + submodule_name="ast_helpers", + lineno=None, + col_offset=None, + ): + # If we simply copy location from origin node, it contains a way to wide range, which cause location in traceback to be wrong. + def set_location(node, lineno, col_offset): + if lineno and col_offset: + node.lineno = lineno + node.end_lineno = lineno + node.col_offset = col_offset + node.end_col_offset = col_offset + + base = ast.Name(id=top_module_name, ctx=ast.Load()) + set_location(base, lineno, col_offset) + if submodule_name: + base = ast.Attribute(value=base, attr=submodule_name, ctx=ast.Load()) + set_location(base, lineno, col_offset) + node = ast.Attribute(value=base, attr=func_name, ctx=ast.Load()) + set_location(node, lineno, col_offset) + return node def _get_module_imports(self, decorated_func): """Extract imports from the module containing the decorated function""" - imports = OrderedDict() + imports = [] + # Get the module containing the decorated function if module := inspect.getmodule(decorated_func): try: @@ -167,7 +207,13 @@ class DSLPreprocessor(ast.NodeTransformer): for node in ast.walk(module_ast): if isinstance(node, ast.Import): for name in node.names: - imports[(name.name, None)] = alias(name) + imports.append( + ImportInfo( + module_path=name.name, + attr_name=None, + alias_name=alias(name), + ) + ) elif isinstance(node, ast.ImportFrom): module_name = node.module if node.level > 0: @@ -177,7 +223,13 @@ class DSLPreprocessor(ast.NodeTransformer): )[0] module_name = f"{package_name}.{module_name}" for name in node.names: - imports[(module_name, name.name)] = alias(name) + imports.append( + ImportInfo( + module_path=module_name, + attr_name=name.name, + alias_name=alias(name), + ) + ) except (IOError, TypeError): pass @@ -188,7 +240,12 @@ class DSLPreprocessor(ast.NodeTransformer): module_imports = self._get_module_imports(original_function) # Import all required modules - for (module_path, attr_name), alias_name in module_imports.items(): + for import_info in module_imports: + module_path, attr_name, alias_name = ( + import_info.module_path, + import_info.attr_name, + import_info.alias_name, + ) try: module = importlib.import_module(module_path) if attr_name: @@ -267,10 +324,24 @@ class DSLPreprocessor(ast.NodeTransformer): # Step 2. Transform the function transformed_tree = self.visit(tree) + + # Step 3. Import cutlass and base_dsl + top_module_name = ".".join(self.client_module_name) + import_stmts = [] + if self.import_top_module: + import_stmts.append(ast.Import(names=[ast.alias(name=top_module_name)])) + import_stmts.append( + ast.Import( + names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")] + ) + ) + transformed_tree.body = import_stmts + transformed_tree.body + + # Step 4. Import cutlass and base_dsl ast.fix_missing_locations(transformed_tree) combined_body = transformed_tree.body - # Step 3. Return the transformed tree + # Step 5. Return the transformed tree return combined_body def check_early_exit(self, tree, kind): @@ -395,23 +466,45 @@ class DSLPreprocessor(ast.NodeTransformer): """ # we need orderedset to keep the insertion order the same. otherwise generated IR is different each time - read_args = OrderedSet() write_args = OrderedSet() + invoked_args = OrderedSet() local_closure = self.local_closures file_name = self.file_name region_node = node class RegionAnalyzer(ast.NodeVisitor): + force_store = False def visit_Name(self, node): """ - Mark every load as read, and every store as write. + Mark every store as write. """ - if isinstance(node.ctx, ast.Load): - read_args.add(node.id) - elif isinstance(node.ctx, ast.Store): + if isinstance(node.ctx, ast.Store) or self.force_store: write_args.add(node.id) + def visit_Subscript(self, node): + # When subscript occurs on the lhs of an assignment, the `Name` is still a load, but `Subscript` is marked as `Store`. + # We need to force the store for the `Name` to be marked as write. + if isinstance(node.ctx, ast.Store): + self.force_store = True + self.visit(node.value) + self.force_store = False + self.visit(node.slice) + else: + self.generic_visit(node) + + def visit_Assign(self, node): + self.force_store = True + [self.visit(target) for target in node.targets] + self.force_store = False + self.visit(node.value) + + def visit_AugAssign(self, node): + self.force_store = True + self.visit(node.target) + self.force_store = False + self.visit(node.value) + @staticmethod def get_call_base(func_node): if isinstance(func_node, ast.Attribute): @@ -454,21 +547,20 @@ class DSLPreprocessor(ast.NodeTransformer): # Classes are mutable by default. Mark them as write. If they are # dataclass(frozen=True), treat them as read in runtime. if base_name is not None and base_name not in ("self"): - write_args.add(base_name) + invoked_args.add(base_name) self.generic_visit(node) analyzer = RegionAnalyzer() analyzer.visit(ast.Module(body=node)) - # Argument can be Load and Store. We should just mark it as Store. - read_args = read_args - write_args + # If arg is both write and invoke, remove from invoked_args + invoked_args = invoked_args - write_args - used_args = read_args.intersections(active_symbols) - iter_args = write_args.intersections(active_symbols) - flattend_args = used_args | iter_args + write_args = list(write_args.intersections(active_symbols)) + invoked_args = list(invoked_args.intersections(active_symbols)) - return list(used_args), list(iter_args), list(flattend_args) + return write_args + invoked_args, len(write_args) def extract_range_args(self, iter_node): args = iter_node.args @@ -500,9 +592,24 @@ class DSLPreprocessor(ast.NodeTransformer): keywords.get("unroll_full", ast.Constant(value=False)), ) - def extract_pipelining_args(self, iter_node): + def issue_deprecation_warning(self, *, message, category, filename, lineno): + warnings.simplefilter("always", category) # turn off filter + warnings.warn_explicit( + message, category=category, filename=filename, lineno=lineno + ) + warnings.simplefilter("default", category) # reset filter + + def extract_prefetch_stages_args(self, iter_node): keywords = {kw.arg: kw.value for kw in iter_node.keywords} - return keywords.get("pipelining", ast.Constant(value=None)) + if "pipelining" in keywords: + self.issue_deprecation_warning( + message="pipelining is deprecated, use prefetch_stages instead", + category=DeprecationWarning, + filename=self.file_name, + lineno=iter_node.lineno, + ) + return keywords.get("pipelining", ast.Constant(value=None)) + return keywords.get("prefetch_stages", ast.Constant(value=None)) def create_loop_function( self, @@ -513,17 +620,16 @@ class DSLPreprocessor(ast.NodeTransformer): step, unroll, unroll_full, - pipelining, - used_args, - iter_args, - flattened_args, + prefetch_stages, + write_args, + full_write_args_count, ): """ Creates a loop body function with the `loop_selector` decorator. """ func_args = [ast.arg(arg=node.target.id, annotation=None)] - func_args += [ast.arg(arg=var, annotation=None) for var in flattened_args] + func_args += [ast.arg(arg=var, annotation=None) for var in write_args] # Create the loop body transformed_body = [] @@ -535,13 +641,13 @@ class DSLPreprocessor(ast.NodeTransformer): transformed_body.append(transformed_stmt) # Handle the return for a single iterated argument correctly - if len(iter_args) == 0: + if len(write_args) == 0: transformed_body.append(ast.Return()) else: transformed_body.append( ast.Return( value=ast.List( - elts=[ast.Name(id=var, ctx=ast.Load()) for var in iter_args], + elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], ctx=ast.Load(), ) ) @@ -550,30 +656,33 @@ class DSLPreprocessor(ast.NodeTransformer): # Define the decorator with parameters decorator = ast.copy_location( ast.Call( - func=ast.Name(id=self.DECORATOR_FOR_STATEMENT, ctx=ast.Load()), + func=self._create_module_attribute( + self.DECORATOR_FOR_STATEMENT, + lineno=node.lineno, + col_offset=node.col_offset, + ), args=[start, stop, step], keywords=[ ast.keyword(arg="unroll", value=unroll), ast.keyword(arg="unroll_full", value=unroll_full), - ast.keyword(arg="pipelining", value=pipelining), + ast.keyword(arg="prefetch_stages", value=prefetch_stages), ast.keyword( - arg="used_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args], - ctx=ast.Load(), - ), - ), - ast.keyword( - arg="iter_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in iter_args], - ctx=ast.Load(), + arg="write_args", + value=ast.List( + elts=[ + ast.Name(id=arg, ctx=ast.Load()) for arg in write_args + ], + ctx=ast.Load(), ), ), ast.keyword( - arg="iter_arg_names", + arg="full_write_args_count", + value=ast.Constant(value=full_write_args_count), + ), + ast.keyword( + arg="write_args_names", value=ast.List( - elts=[ast.Constant(value=arg) for arg in iter_args], + elts=[ast.Constant(value=arg) for arg in write_args], ctx=ast.Load(), ), ), @@ -635,7 +744,14 @@ class DSLPreprocessor(ast.NodeTransformer): # else # return and_(lhs, rhs) short_circuit_value = ast.Constant(value=False) - helper_func = ast.Name(id="and_", ctx=ast.Load()) + helper_func = self._create_module_attribute( + "and_", + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ) + self.import_top_module = True # Transform "or" to "or_" elif isinstance(node.op, ast.Or): # Create an if-else statement in AST form @@ -644,7 +760,14 @@ class DSLPreprocessor(ast.NodeTransformer): # else # return or_(lhs, rhs) short_circuit_value = ast.Constant(value=True) - helper_func = ast.Name(id="or_", ctx=ast.Load()) + helper_func = self._create_module_attribute( + "or_", + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ) + self.import_top_module = True else: # BoolOp should be either And or Or raise DSLAstPreprocessorError( @@ -696,22 +819,30 @@ class DSLPreprocessor(ast.NodeTransformer): # Transform "not" to "~" as we overload __invert__ if isinstance(node.op, ast.Not): - func_name = ast.Name(id="not_", ctx=ast.Load()) + func_name = self._create_module_attribute( + "not_", + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ) + self.import_top_module = True return ast.copy_location( ast.Call(func=func_name, args=[node.operand], keywords=[]), node ) return node - @staticmethod - def _insert_range_value_check(node): + def _insert_range_value_check(self, node): """ Insert a check for range arguments """ range_inputs = node.iter.args check_call = ast.copy_location( ast.Call( - func=ast.Name(id="range_value_check", ctx=ast.Load()), + func=self._create_module_attribute( + "range_value_check", lineno=node.lineno, col_offset=node.col_offset + ), args=range_inputs, keywords=[], ), @@ -726,40 +857,60 @@ class DSLPreprocessor(ast.NodeTransformer): node.iter, ) + def _insert_cf_symbol_check(self, func): + """ + Insert a check for range symbol + """ + check_call = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + "cf_symbol_check", lineno=func.lineno, col_offset=func.col_offset + ), + args=[deepcopy(func)], + keywords=[], + ), + func, + ) + return ast.Expr(check_call) + def visit_For(self, node): + # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop. + range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter) + if range_kind == "range_constexpr" or range_kind == None: + self.generic_visit(node) + if range_kind == "range_constexpr": + check_call = self._insert_cf_symbol_check(node.iter.func) + # Rewrite range_constexpr to range + node.iter.func = ast.Name(id="range", ctx=ast.Load()) + self._insert_range_value_check(node) + return [check_call, node] + return node + active_symbols = self.scope_manager.get_active_symbols() with self.scope_manager: if isinstance(node.target, ast.Name): self.scope_manager.add_to_scope(node.target.id) - # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop. - range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter) - if range_kind == "range_constexpr" or range_kind == None: - self.generic_visit(node) - if range_kind == "range_constexpr": - # Rewrite range_constexpr to range - node.iter.func.id = "range" - self._insert_range_value_check(node) - return node - if range_kind == "range_dynamic": # Generate a warning - warnings.simplefilter("always", DeprecationWarning) # turn off filter - warnings.warn_explicit( - "range_dynamic is deprecated and will be removed in the future, please remove it.", + self.issue_deprecation_warning( + message="range_dynamic is deprecated and will be removed in the future, please remove it.", category=DeprecationWarning, filename=self.file_name, lineno=node.iter.lineno, ) - warnings.simplefilter("default", DeprecationWarning) # reset filter warning_call = None if range_kind == "range" and is_builtin_range and not has_keyword: # Warn about possible performance regression due to behavior change warning_call = ast.Expr( ast.Call( - func=ast.Name(id="range_perf_warning", ctx=ast.Load()), + func=self._create_module_attribute( + "range_perf_warning", + lineno=node.lineno, + col_offset=node.col_offset, + ), args=[ ast.Constant(value=self.file_name), ast.Constant(value=node.iter.lineno), @@ -770,7 +921,19 @@ class DSLPreprocessor(ast.NodeTransformer): ) ast.copy_location(warning_call, node.iter) + is_prefixed_range = range_kind == "range" and not is_builtin_range + check_call = None + if range_kind == "range_dynamic" or is_prefixed_range: + # Insert a check for range symbol + if not is_prefixed_range: + check_call = self._insert_cf_symbol_check(node.iter.func) + else: + # Get toplevel module + check_call = self._insert_cf_symbol_check(node.iter.func.value) + new_for_node = self.transform_for_loop(node, active_symbols) + if check_call is not None: + new_for_node = [check_call] + new_for_node return new_for_node if warning_call is None else [warning_call] + new_for_node @@ -946,12 +1109,12 @@ class DSLPreprocessor(ast.NodeTransformer): start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter) unroll, unroll_full = self.extract_unroll_args(node.iter) - pipelining = self.extract_pipelining_args(node.iter) - used_args, iter_args, flat_args = self.analyze_region_variables( + prefetch_stages = self.extract_prefetch_stages_args(node.iter) + write_args, full_write_args_count = self.analyze_region_variables( node, active_symbols ) - if has_step: + if has_step and self.client_module_name[0] == "cutlass": start, stop, step, exprs = self._handle_negative_step( node, start_expr, stop_expr, step_expr ) @@ -972,13 +1135,12 @@ class DSLPreprocessor(ast.NodeTransformer): step, unroll, unroll_full, - pipelining, - used_args, - iter_args, - flat_args, + prefetch_stages, + write_args, + full_write_args_count, ) - assign = ast.copy_location(self.create_loop_call(func_name, iter_args), node) + assign = ast.copy_location(self.create_loop_call(func_name, write_args), node) # This should work fine as it modifies the AST structure exprs = exprs + [func_def, assign] @@ -997,10 +1159,6 @@ class DSLPreprocessor(ast.NodeTransformer): return exprs - def visit_Name(self, node): - self.generic_visit(node) - return node - def visit_Assert(self, node): test = self.visit(node.test) @@ -1012,7 +1170,9 @@ class DSLPreprocessor(ast.NodeTransformer): # Rewrite to assert_executor(test, msg) new_node = ast.Expr( ast.Call( - func=ast.Name(id=self.ASSERT_EXECUTOR, ctx=ast.Load()), + func=self._create_module_attribute( + self.ASSERT_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), args=[], keywords=args, ) @@ -1024,7 +1184,9 @@ class DSLPreprocessor(ast.NodeTransformer): def visit_Call(self, node): func = node.func - self.generic_visit(node) + # Visit args and kwargs + node.args = [self.visit(arg) for arg in node.args] + node.keywords = [self.visit(kwarg) for kwarg in node.keywords] # Rewrite call to some built-in functions if isinstance(func, ast.Name): @@ -1032,7 +1194,11 @@ class DSLPreprocessor(ast.NodeTransformer): if func.id == "bool": return ast.copy_location( ast.Call( - func=ast.Name(id=self.BOOL_CAST, ctx=ast.Load()), + func=self._create_module_attribute( + self.BOOL_CAST, + lineno=node.lineno, + col_offset=node.col_offset, + ), args=[node.args[0]], keywords=[], ), @@ -1044,18 +1210,38 @@ class DSLPreprocessor(ast.NodeTransformer): ) return ast.copy_location( ast.Call( - func=ast.Name(id=helper_func, ctx=ast.Load()), + func=self._create_module_attribute( + helper_func, lineno=node.lineno, col_offset=node.col_offset + ), args=[node.args[0]], keywords=[], ), node, ) + elif func.id in ["min", "max"]: + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + func.id, + top_module_name="cutlass", + submodule_name=None, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[node.args[0], node.args[1]], + keywords=[], + ), + node, + ) elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): def create_downcast_call(arg): return ast.copy_location( ast.Call( - func=ast.Name( - id=self.IMPLICIT_DOWNCAST_NUMERIC_TYPE, ctx=ast.Load() + func=self._create_module_attribute( + self.IMPLICIT_DOWNCAST_NUMERIC_TYPE, + submodule_name="typing", + lineno=node.lineno, + col_offset=node.col_offset, ), args=[arg], keywords=[], @@ -1084,6 +1270,8 @@ class DSLPreprocessor(ast.NodeTransformer): return ast.copy_location( ast.Call(func=func, args=args, keywords=kwargs), node ) + else: + node.func = self.visit(node.func) return node @@ -1113,9 +1301,24 @@ class DSLPreprocessor(ast.NodeTransformer): return node def visit_Name(self, node): - self.generic_visit(node) - if node.id == "_" and isinstance(node.ctx, ast.Load): + isLoad = isinstance(node.ctx, ast.Load) + if node.id in ["max", "min", "any", "all"] and isLoad: + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + "redirect_builtin_function", + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[node], + keywords=[], + ), + node, + ) + elif node.id == "_" and isLoad: raise DSLAstPreprocessorError("Read '_' is not allowed") + else: + self.generic_visit(node) return node def check_decorator(self, node: ast.AST) -> bool: @@ -1205,28 +1408,29 @@ class DSLPreprocessor(ast.NodeTransformer): return node def visit_While(self, node): - active_symbols = self.scope_manager.get_active_symbols() - # print(active_symbols) - with self.scope_manager: - # Constexpr doesn't get preprocessed - if self.is_node_constexpr(node): - self.generic_visit(node) - return node + # Constexpr doesn't get preprocessed + if self.is_node_constexpr(node): + self.generic_visit(node) + check = self._insert_cf_symbol_check(node.test.func) + return [check, node] + active_symbols = self.scope_manager.get_active_symbols() + + with self.scope_manager: # Check for early exit and raise exception self.check_early_exit(node, "while") - used_args, yield_args, flat_args = self.analyze_region_variables( + write_args, full_write_args_count = self.analyze_region_variables( node, active_symbols ) func_name = f"while_region_{self.counter}" self.counter += 1 func_def = self.create_while_function( - func_name, node, used_args, yield_args, flat_args + func_name, node, write_args, full_write_args_count ) assign = ast.copy_location( - self.create_loop_call(func_name, yield_args), node + self.create_loop_call(func_name, write_args), node ) return [func_def, assign] @@ -1243,7 +1447,7 @@ class DSLPreprocessor(ast.NodeTransformer): self.generic_visit(node) return node - def create_if_call(self, func_name, yield_args, flat_args): + def create_if_call(self, func_name, yield_args): """Creates the assignment statement for the if function call""" if not yield_args: return ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())) @@ -1272,6 +1476,7 @@ class DSLPreprocessor(ast.NodeTransformer): # Emit # node if type(pred) == bool else select_(pred, body, orelse) # so if pred is a python bool, use python to short-circuit and avoid emit arith.select + self.import_top_module = True return ast.copy_location( ast.IfExp( test=ast.Compare( @@ -1285,7 +1490,9 @@ class DSLPreprocessor(ast.NodeTransformer): ), body=node, # Original ternary expression orelse=ast.Call( - func=ast.Name(id="select_", ctx=ast.Load()), + func=self._create_module_attribute( + "select_", top_module_name="cutlass", submodule_name=None + ), args=[ node.test, node.body, @@ -1330,7 +1537,7 @@ class DSLPreprocessor(ast.NodeTransformer): call = ast.copy_location( ast.Call( - func=ast.Name(id=self.COMPARE_EXECUTOR, ctx=ast.Load()), + func=self._create_module_attribute(self.COMPARE_EXECUTOR), args=[], keywords=keywords, ), @@ -1340,41 +1547,36 @@ class DSLPreprocessor(ast.NodeTransformer): return call def visit_If(self, node): + # const_expr doesn't get preprocessed + if self.is_node_constexpr(node): + self.generic_visit(node) + check = self._insert_cf_symbol_check(node.test.func) + return [check, node] + active_symbols = self.scope_manager.get_active_symbols() with self.scope_manager: - # non-dynamic expr doesn't get preprocessed - if self.is_node_constexpr(node): - self.generic_visit(node) - return node - # Check for early exit and raise exception self.check_early_exit(node, "if") - used_args, yield_args, flat_args = self.analyze_region_variables( + yield_args, full_write_args_count = self.analyze_region_variables( node, active_symbols ) func_name = f"if_region_{self.counter}" self.counter += 1 func_def = self.create_if_function( - func_name, node, used_args, yield_args, flat_args - ) - assign = ast.copy_location( - self.create_if_call(func_name, yield_args, flat_args), node + func_name, node, yield_args, full_write_args_count ) + assign = ast.copy_location(self.create_if_call(func_name, yield_args), node) return [func_def, assign] - def create_if_function( - self, func_name, node, used_args, yield_args, flattened_args - ): + def create_if_function(self, func_name, node, write_args, full_write_args_count): test_expr = self.visit(node.test) - pred_name = self.make_func_param_name("pred", flattened_args) + pred_name = self.make_func_param_name("pred", write_args) func_args = [ast.arg(arg=pred_name, annotation=None)] - func_args += [ast.arg(arg=var, annotation=None) for var in flattened_args] - func_args_then_else = [ - ast.arg(arg=var, annotation=None) for var in flattened_args - ] + func_args += [ast.arg(arg=var, annotation=None) for var in write_args] + func_args_then_else = [ast.arg(arg=var, annotation=None) for var in write_args] then_body = [] for stmt in node.body: @@ -1386,7 +1588,7 @@ class DSLPreprocessor(ast.NodeTransformer): # Create common return list for all blocks return_list = ast.List( - elts=[ast.Name(id=var, ctx=ast.Load()) for var in yield_args], + elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], ctx=ast.Load(), ) @@ -1424,16 +1626,9 @@ class DSLPreprocessor(ast.NodeTransformer): arg="pred", value=test_expr ), # ast.Name(id="pred", ctx=ast.Load()) ast.keyword( - arg="used_args", + arg="write_args", value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args], - ctx=ast.Load(), - ), - ), - ast.keyword( - arg="yield_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args], + elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], ctx=ast.Load(), ), ), @@ -1442,7 +1637,11 @@ class DSLPreprocessor(ast.NodeTransformer): # Create decorator decorator = ast.copy_location( ast.Call( - func=ast.Name(id=self.DECORATOR_IF_STATEMENT, ctx=ast.Load()), + func=self._create_module_attribute( + self.DECORATOR_IF_STATEMENT, + lineno=node.lineno, + col_offset=node.col_offset, + ), args=[], keywords=decorator_keywords, ), @@ -1453,23 +1652,20 @@ class DSLPreprocessor(ast.NodeTransformer): execute_keywords = [ ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())), ast.keyword( - arg="used_args", + arg="write_args", value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args], + elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], ctx=ast.Load(), ), ), ast.keyword( - arg="yield_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args], - ctx=ast.Load(), - ), + arg="full_write_args_count", + value=ast.Constant(value=full_write_args_count), ), ast.keyword( - arg="yield_arg_names", + arg="write_args_names", value=ast.List( - elts=[ast.Constant(value=arg) for arg in yield_args], + elts=[ast.Constant(value=arg) for arg in write_args], ctx=ast.Load(), ), ), @@ -1479,12 +1675,12 @@ class DSLPreprocessor(ast.NodeTransformer): ] # Handle different cases - if not yield_args and node.orelse == []: - # No yield_args case - only then_block needed + if not write_args and node.orelse == []: + # No write_args case - only then_block needed execute_call = ast.copy_location( ast.Call( - func=ast.copy_location( - ast.Name(id=self.IF_EXECUTOR, ctx=ast.Load()), node + func=self._create_module_attribute( + self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset ), args=[], keywords=execute_keywords, @@ -1501,7 +1697,7 @@ class DSLPreprocessor(ast.NodeTransformer): nested_if_name = elif_region_name # Recursion for nested elif nested_if = self.create_if_function( - nested_if_name, elif_node, used_args, yield_args, flattened_args + nested_if_name, elif_node, write_args, full_write_args_count ) else_block = ast.FunctionDef( name=else_block_name, @@ -1551,7 +1747,9 @@ class DSLPreprocessor(ast.NodeTransformer): execute_call = ast.copy_location( ast.Call( - func=ast.Name(id=self.IF_EXECUTOR, ctx=ast.Load()), + func=self._create_module_attribute( + self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), args=[], keywords=execute_keywords, ), @@ -1573,66 +1771,61 @@ class DSLPreprocessor(ast.NodeTransformer): node, ) - def create_while_function( - self, func_name, node, used_args, yield_args, flattened_args - ): + def create_while_function(self, func_name, node, write_args, full_write_args_count): """Create a while function that looks like: - @while_selector(pred, used_args=[], yield_args=[]) - def while_region(pred, flattened_args): - def while_before_block(*used_args, *yield_args): + @while_selector(pred, write_args=[]) + def while_region(pred, write_args): + def while_before_block(*write_args): # Note that during eval of pred can possibly alter yield_args - return *pred, yield_args - def while_after_block(*used_args, yield_args): + return *pred, write_args + def while_after_block(*write_args): ...loop_body_transformed... - return yield_args - return self.while_executor(pred, used_args, yield_args, + return write_args + return self.while_executor(pred, write_args, while_before_block, while_after_block, constexpr) - yield_args = while_region(pred, flattened_args) + write_args = while_region(pred, write_args) Which will later be executed as psuedo-code: # Dynamic mode: - scf.WhileOp(types(yield_args), yield_args) + scf.WhileOp(types(write_args), write_args) with InsertionPoint(before_block): - cond, yield_args = while_before_block(*flattened_args) - scf.ConditionOp(cond, yield_args) + cond, write_args = while_before_block(*write_args) + scf.ConditionOp(cond, write_args) with InsertionPoint(after_block): - yield_args = while_after_block(yield_args) - scf.YieldOp(yield_args) + write_args = while_after_block(write_args) + scf.YieldOp(write_args) return while_op.results_ # Const mode: - cond, yield_args = while_before_block(yield_args) + cond, write_args = while_before_block(write_args) while pred: - yield_args = body_block(yield_args) - cond, yield_args = while_before_block(yield_args) - return yield_args + write_args = body_block(write_args) + cond, write_args = while_before_block(write_args) + return write_args """ test_expr = self.visit(node.test) - pred_name = self.make_func_param_name("pred", flattened_args) + pred_name = self.make_func_param_name("pred", write_args) # Section: decorator construction decorator_keywords = [ ast.keyword(arg="pred", value=test_expr), ast.keyword( - arg="used_args", + arg="write_args", value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args], - ctx=ast.Load(), - ), - ), - ast.keyword( - arg="yield_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args], + elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], ctx=ast.Load(), ), ), ] decorator = ast.copy_location( ast.Call( - func=ast.Name(id=self.DECORATOR_WHILE_STATEMENT, ctx=ast.Load()), + func=self._create_module_attribute( + self.DECORATOR_WHILE_STATEMENT, + lineno=node.lineno, + col_offset=node.col_offset, + ), args=[], keywords=decorator_keywords, ), @@ -1643,8 +1836,7 @@ class DSLPreprocessor(ast.NodeTransformer): while_before_block_name = f"while_before_block_{self.counter}" while_after_block_name = f"while_after_block_{self.counter}" self.counter += 1 - block_args_args = [ast.arg(arg=var, annotation=None) for var in used_args] - block_args_args += [ast.arg(arg=var, annotation=None) for var in yield_args] + block_args_args = [ast.arg(arg=var, annotation=None) for var in write_args] block_args = ast.arguments( posonlyargs=[], args=block_args_args, @@ -1654,7 +1846,7 @@ class DSLPreprocessor(ast.NodeTransformer): ) yield_args_ast_name_list = ast.List( - elts=[ast.Name(id=var, ctx=ast.Load()) for var in yield_args], + elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args], ctx=ast.Load(), ) @@ -1698,18 +1890,15 @@ class DSLPreprocessor(ast.NodeTransformer): execute_keywords = [ ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())), ast.keyword( - arg="used_args", + arg="write_args", value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args], + elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], ctx=ast.Load(), ), ), ast.keyword( - arg="yield_args", - value=ast.List( - elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args], - ctx=ast.Load(), - ), + arg="full_write_args_count", + value=ast.Constant(value=full_write_args_count), ), ast.keyword( arg="while_before_block", @@ -1720,23 +1909,25 @@ class DSLPreprocessor(ast.NodeTransformer): value=ast.Name(id=while_after_block_name, ctx=ast.Load()), ), ast.keyword( - arg="yield_arg_names", + arg="write_args_names", value=ast.List( - elts=[ast.Constant(value=arg) for arg in yield_args], + elts=[ast.Constant(value=arg) for arg in write_args], ctx=ast.Load(), ), ), ] execute_call = ast.Call( - func=ast.Name(id=self.WHILE_EXECUTOR, ctx=ast.Load()), + func=self._create_module_attribute( + self.WHILE_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), args=[], keywords=execute_keywords, ) # Putting everything together, FunctionDef for while_region func_args_args = [ast.arg(arg=pred_name, annotation=None)] - func_args_args += [ast.arg(arg=var, annotation=None) for var in flattened_args] + func_args_args += [ast.arg(arg=var, annotation=None) for var in write_args] func_args = ast.arguments( posonlyargs=[], args=func_args_args, diff --git a/python/CuTeDSL/base_dsl/cache_helpers.py b/python/CuTeDSL/base_dsl/cache_helpers.py index 8ea08874..5d9234f2 100644 --- a/python/CuTeDSL/base_dsl/cache_helpers.py +++ b/python/CuTeDSL/base_dsl/cache_helpers.py @@ -139,8 +139,7 @@ def dump_cache_to_path( dsl_name, jit_cache, cache_limit, path=default_generated_ir_path ): log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache)) - if not os.path.exists(path): - os.makedirs(path) + os.makedirs(path, exist_ok=True) original_path = os.getcwd() try: os.chdir(path) diff --git a/python/CuTeDSL/base_dsl/compiler.py b/python/CuTeDSL/base_dsl/compiler.py index 8776c91b..f8b2da07 100644 --- a/python/CuTeDSL/base_dsl/compiler.py +++ b/python/CuTeDSL/base_dsl/compiler.py @@ -205,6 +205,8 @@ class CompileOptions: self._parser.add_argument( "--enable-device-assertions", action="store_true", default=False ) + self._parser.add_argument("--link-libraries", type=str, default="") + try: self._options = self._parser.parse_args(options.split()) except SystemExit as e: diff --git a/python/CuTeDSL/base_dsl/dsl.py b/python/CuTeDSL/base_dsl/dsl.py index c6ece00c..2b17d22b 100644 --- a/python/CuTeDSL/base_dsl/dsl.py +++ b/python/CuTeDSL/base_dsl/dsl.py @@ -32,13 +32,14 @@ import hashlib from functools import lru_cache, wraps from collections import namedtuple from abc import ABC, abstractmethod -from typing import Any, Union, Tuple, get_origin, get_args -from types import FunctionType +from typing import Any, Union, Tuple, get_origin, get_args, List +from types import FunctionType, SimpleNamespace import warnings from . import typing as t from .env_manager import EnvironmentVarManager from .compiler import CompileOptions +from .ast_helpers import DSLOptimizationWarning # ============================================================================= # CUDA Python @@ -56,7 +57,7 @@ from .utils.timer import timer from .utils.logger import setup_log, log from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry -from .runtime.tensor_descriptor import TensorDescriptor + from .ast_preprocessor import DSLPreprocessor from .common import * from .typing import ( @@ -73,12 +74,6 @@ from .._mlir import runtime as rt from .._mlir.extras import types as T from .._mlir.dialects import arith, math, func -# ============================================================================= -# cutlass.dlpack_runtime -# ============================================================================= - -from .runtime.dlpack_runtime import dlpack_to_tensor_desc, mark_layout_dynamic - # ============================================================================= # Global Variables # ============================================================================= @@ -177,6 +172,7 @@ def is_dynamic_expression(value): return True return False + def extract_mlir_values(obj): """ Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values @@ -186,6 +182,10 @@ def extract_mlir_values(obj): res = obj.__extract_mlir_values__() elif isinstance(obj, (tuple, list)): res = sum((extract_mlir_values(x) for x in obj), []) + elif isinstance(obj, SimpleNamespace): + res = [] + for k, v in obj.__dict__.items(): + res.extend(extract_mlir_values(v)) # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values elif isinstance(obj, set): raise DSLRuntimeError( @@ -215,6 +215,13 @@ def new_from_mlir_values(obj, values): values = values[n_items:] obj_ty = type(obj) return obj_ty(res) + elif isinstance(obj, SimpleNamespace): + res = SimpleNamespace() + for k, v in obj.__dict__.items(): + n_items = len(get_mlir_types(v)) + res.__dict__[k] = new_from_mlir_values(v, values[:n_items]) + values = values[n_items:] + return res elif isinstance(obj, set): raise DSLRuntimeError( "Sets are not supported in new_from_mlir_values to ensure order preservation", @@ -249,8 +256,6 @@ class DSLCallable: Methods: __call__(*args, **kwargs): Calls the wrapped function and clears it. - get_arg_spec(): Returns the argument specification of the function. - get_signature(): Returns the signature of the function. """ def __init__(self, func): @@ -266,23 +271,23 @@ class DSLCallable: assert self.func is not None, "DSLCallable is already called" return self.func + @property + def __signature__(self): + return inspect.signature(self.__func__) + @property def __name__(self): return self.__func__.__name__ - def get_arg_spec(self): - return inspect.getfullargspec(self.__func__) - - def get_signature(self): - return inspect.signature(self.__func__) - class BaseDSL: gpu_module = None def __init__( self, + *, name: str, + dsl_package_name: List[str], compiler_provider: Any, pass_sm_arch_name: str, device_compilation_only=False, @@ -293,6 +298,7 @@ class BaseDSL: Parameters: - name (str): Name of DSL, used for environment variables and logging. + - package_name (str): Name of the package, used for the preprocessor. - compiler_provider (MLIR dialect): Provider for compiler. - pass_sm_arch_name (str): The keyword name of the SM. - device_compilation_only (bool) : Only device code, and call it via cuda driver @@ -330,6 +336,9 @@ class BaseDSL: self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}" # set warning + if not self.envar.enable_optimization_warnings: + # By default, optimization warnings are disabled + warnings.filterwarnings("ignore", category=DSLOptimizationWarning) if self.envar.warnings_as_errors: warnings.filterwarnings("error") if self.envar.warnings_ignore: @@ -355,7 +364,7 @@ class BaseDSL: self.compile_options = CompileOptions() if preprocess: - self.preprocessor = DSLPreprocessor() + self.preprocessor = DSLPreprocessor(dsl_package_name) log().info(f"Initializing {name} DSL") log().debug(f"Logger initialized for {self.name}") @@ -656,7 +665,7 @@ class BaseDSL: return ir_args, ir_kwargs @abstractmethod - def _generate_mlir_type_for_tensor_descriptor(self, tensor: TensorDescriptor): + def _generate_mlir_type_for_tensor_descriptor(self, tensor): """ Generate MLIR type for the tensor descriptor. """ @@ -671,13 +680,6 @@ class BaseDSL: """ pass - @abstractmethod - def _get_module_globals(self): - """ - Get the module's globals. - """ - pass - def _get_globals(self): """ Combines global and local variables from the current context and the @@ -690,43 +692,21 @@ class BaseDSL: AST preprocessor generates a new python code, so the resulting globals dictionary is used to execute the python code. """ - all_globals = self._get_module_globals().copy() + all_globals = {} if self.frame: all_globals.update(self.frame.f_globals) all_globals.update(self.frame.f_locals) return all_globals + @abstractmethod def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: - return isinstance( - maybe_tensor_descriptor, TensorDescriptor - ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor) + pass + @abstractmethod def _handle_tensor_descriptor( self, maybe_tensor, arg_name: str, need_gpu_memory: bool - ) -> TensorDescriptor: - if self._is_tensor_descriptor(maybe_tensor): - tensor = ( - maybe_tensor - if isinstance(maybe_tensor, TensorDescriptor) - else TensorDescriptor(maybe_tensor) - ) - if need_gpu_memory and not tensor.is_in_device: - log().info( - "FAIL name=[%s] tensor=[%s] in_gpu=[%s]", - arg_name, - tensor, - tensor.is_in_device, - ) - raise DSLRuntimeError( - f'Tensor "{arg_name}" is tensor "{tensor}" ' - "is not in the GPU memory. " - ) - - return tensor - - raise DSLRuntimeError( - f"Argument {arg_name} could not be transformed into a TensorDescriptor." - ) + ) -> Any: + pass def _validate_arg(self, arg, arg_index, arg_name, arg_spec): """ @@ -882,10 +862,11 @@ class BaseDSL: cluster: list = None grid: list = field(default_factory=lambda: [1, 1, 1]) block: list = field(default_factory=lambda: [1, 1, 1]) - smem: int = 0 + smem: int = None async_deps: list = field(default_factory=list) has_cluster: bool = False min_blocks_per_mp: int = 0 + auto_smem: bool = False def __post_init__(self): if len(self.grid) != 3: @@ -893,6 +874,10 @@ class BaseDSL: if len(self.block) != 3: raise DSLRuntimeError(f"Expect 3d block!") + if self.smem is None: + self.smem = 0 + self.auto_smem = True + self.has_cluster = self.cluster is not None if self.cluster is None: self.cluster = [None, None, None] @@ -1116,8 +1101,6 @@ class BaseDSL: try: result = funcBody(*ir_args, **ir_kwargs) func.ReturnOp([]) - except DSLAstPreprocessorError as pp_error: - raise pp_error except NameError as name_error: raise DSLRuntimeError( f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥", @@ -1127,11 +1110,6 @@ class BaseDSL: except DSLRuntimeError as dsl_error: # Throw it's already a DSL error raise dsl_error - except Exception as general_e: - # Transform internal error to a DSL error - raise DSLRuntimeError( - f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥" - ) from general_e return module, result # Build IR module @@ -1328,10 +1306,8 @@ class BaseDSL: raise DSLRuntimeError("Function body is not set.") # Pass the actual function object to inspect.signature to get the signature. - if isinstance(self.funcBody, DSLCallable): - sig = self.funcBody.get_signature() - else: - sig = inspect.signature(self.funcBody) + sig = inspect.signature(self.funcBody) + function_name = self.funcBody.__name__ bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) @@ -1382,10 +1358,7 @@ class BaseDSL: # Check the number of arguments sig = self._check_arg_count(*args, **kwargs) - if isinstance(funcBody, DSLCallable): - args_spec = funcBody.get_arg_spec() - else: - args_spec = inspect.getfullargspec(funcBody) + args_spec = inspect.getfullargspec(funcBody) # Canonicalize the input arguments canonicalized_args, canonicalized_kwargs = self._canonicalize_args( @@ -1447,7 +1420,7 @@ class BaseDSL: return cuda_helpers.stream_create() def _execute_cuda( - self, fname_cubin, kernel_name, grid_size, block_size, stream=None + self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None ): """ Executes a specified CUDA kernel from a cubin file, handling module loading, @@ -1471,7 +1444,7 @@ class BaseDSL: grid_size, block_size, stream, - smem_size=16000, + smem_size=smem_size, kernel_args=self.exe_args, ) @@ -1480,7 +1453,13 @@ class BaseDSL: cuda_helpers.stream_sync(stream) def _execute_by_cuda_driver( - self, kernel_generator, generate_cubin, grid_size, block_size, stream=None + self, + kernel_generator, + generate_cubin, + grid_size, + block_size, + smem_size, + stream=None, ): """ This function builds IR and execute the module using cuda driver. @@ -1511,10 +1490,9 @@ class BaseDSL: fname_cubin = generate_cubin(module, kernel_name) # Execute a cuda kernel from cubin - if block_size is None: - # The TileIR driver should set this automatically. - block_size = self.block_size - self._execute_cuda(fname_cubin, kernel_name, grid_size, block_size, stream) + self._execute_cuda( + fname_cubin, kernel_name, grid_size, block_size, smem_size, stream + ) return ret @@ -1587,10 +1565,7 @@ class BaseDSL: kernelGenHelper = dkwargs.get("kernelGenHelper", None) kernel_name = funcBody.__name__ - if isinstance(funcBody, DSLCallable): - args_spec = funcBody.get_arg_spec() - else: - args_spec = inspect.getfullargspec(funcBody) + args_spec = inspect.getfullargspec(funcBody) self.funcBody = funcBody # Give each kernel a unique name. (The same kernel may be diff --git a/python/CuTeDSL/base_dsl/env_manager.py b/python/CuTeDSL/base_dsl/env_manager.py index 4a8f6591..fa683477 100644 --- a/python/CuTeDSL/base_dsl/env_manager.py +++ b/python/CuTeDSL/base_dsl/env_manager.py @@ -58,6 +58,11 @@ def get_int_env_var(var_name, default_value=0): return int(value) if value and value.isdigit() else default_value +@lru_cache(maxsize=None) +def has_env_var(var_name): + return os.getenv(var_name) is not None + + def detect_gpu_arch(prefix): """ Attempts to detect the machine's GPU architecture. @@ -256,6 +261,7 @@ class EnvironmentVarManager: - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100") - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False) - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False) + - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False) - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False) - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000) @@ -267,7 +273,6 @@ class EnvironmentVarManager: self.prefix = prefix # change if needed # Printing options - self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False) self.print_after_preprocessor = get_bool_env_var( f"{prefix}_PRINT_AFTER_PREPROCESSOR", False ) @@ -275,15 +280,29 @@ class EnvironmentVarManager: self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True) # File options self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False) + # Logging options + self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False) self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False) - # Other options + if ( + has_env_var(f"{prefix}_LOG_LEVEL") + and not self.log_to_console + and not self.log_to_file + ): + log().warning( + f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!" + ) self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1) + + # Other options self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False) self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix)) self.warnings_as_errors = get_bool_env_var( f"{prefix}_WARNINGS_AS_ERRORS", False ) self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False) + self.enable_optimization_warnings = get_bool_env_var( + f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False + ) self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False) self.disable_file_caching = get_bool_env_var( f"{prefix}_DISABLE_FILE_CACHING", False diff --git a/python/CuTeDSL/base_dsl/runtime/__init__.py b/python/CuTeDSL/base_dsl/runtime/__init__.py index 6f8e2feb..ccc475fd 100644 --- a/python/CuTeDSL/base_dsl/runtime/__init__.py +++ b/python/CuTeDSL/base_dsl/runtime/__init__.py @@ -14,16 +14,12 @@ This module provides a runtime utility functions that are needed for the DSL. """ -from . import device_tensor from . import dlpack_types from . import cuda -from . import tensor_descriptor from . import jit_arg_adapters __all__ = [ - "device_tensor", "dlpack_types", "cuda", - "tensor_descriptor", "jit_arg_adapters", ] diff --git a/python/CuTeDSL/base_dsl/runtime/cuda.py b/python/CuTeDSL/base_dsl/runtime/cuda.py index 278a5118..c2ad2203 100644 --- a/python/CuTeDSL/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/base_dsl/runtime/cuda.py @@ -309,7 +309,7 @@ def get_kernel_function(module, kernel_name): return kernel -def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None): +def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None): """ Launches the CUDA kernel. """ diff --git a/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py b/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py index b09d2fcb..1a992ef6 100644 --- a/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py +++ b/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py @@ -183,6 +183,13 @@ class TensorDescriptor: """ return self.device_type == _dpack.DLDeviceType.kDLGPU + @staticmethod + def is_compatible(maybe_tensor_descriptor) -> bool: + """Check if the object is a TensorDescriptor or can be converted to one.""" + return isinstance( + maybe_tensor_descriptor, TensorDescriptor + ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor) + def from_tensor(tensor) -> TensorDescriptor: """Create a TensorDescriptor from a tensor object.""" @@ -192,10 +199,3 @@ def from_tensor(tensor) -> TensorDescriptor: def to_tensor(tensor_descriptor: TensorDescriptor): """Return tensor object from tensor descriptor.""" return tensor_descriptor.tensor - - -def is_tensor_descriptor(maybe_tensor_descriptor) -> bool: - """Check if the object is a TensorDescriptor.""" - return isinstance( - maybe_tensor_descriptor, TensorDescriptor - ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor) diff --git a/python/CuTeDSL/cutlass/cute/__init__.py b/python/CuTeDSL/cutlass/cute/__init__.py index 9076fcb0..8702ed91 100644 --- a/python/CuTeDSL/cutlass/cute/__init__.py +++ b/python/CuTeDSL/cutlass/cute/__init__.py @@ -126,6 +126,7 @@ from .core import ( basic_copy_if, autovec_copy, copy, + copy_atom_call, gemm, # Wrapper classes ComposedLayout, @@ -290,6 +291,7 @@ __all__ = [ "basic_copy_if", "autovec_copy", "copy", + "copy_atom_call", "gemm", # Tensor creation "full", diff --git a/python/CuTeDSL/cutlass/cute/arch/mbar.py b/python/CuTeDSL/cutlass/cute/arch/mbar.py index 8a6e3cfb..80cb7b0b 100644 --- a/python/CuTeDSL/cutlass/cute/arch/mbar.py +++ b/python/CuTeDSL/cutlass/cute/arch/mbar.py @@ -315,3 +315,35 @@ def mbarrier_arrive( loc=loc, ip=ip, ) + + +@dsl_user_op +def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None: + """ + Arrives on an mbarrier for async load **without incrementing** the arrival count + (`cp.async.mbarrier.arrive.shared ..., noinc=1`). + Used in the warp-specialized kernel when the non-TMA load warp(producer) is not the same + as the math/epilogue warp(consumer). + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + nvvm.cp_async_mbarrier_arrive_shared( + mbar_llvm_ptr, + noinc=True, + loc=loc, + ip=ip, + ) diff --git a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py index 80a4c1d0..69e3b8ac 100644 --- a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +++ b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -11,6 +11,7 @@ from functools import partial from typing import Optional, Tuple, Union, Callable +from typing_extensions import deprecated from cutlass.cutlass_dsl import T, dsl_user_op @@ -642,6 +643,9 @@ def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None): @dsl_user_op +@deprecated( + "cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead" +) def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: return Float32( llvm.inline_asm( @@ -656,15 +660,19 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: ) -# TODO: add `fastmath` flag for this op @dsl_user_op +@deprecated( + "cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead" +) def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: LOG2_E = 1.4426950408889634 return exp2(a * LOG2_E, loc=loc, ip=ip) -# TODO: add `fastmath` flag for this op @dsl_user_op +@deprecated( + "cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead" +) def exp_packed_f32x2( a: Tuple[Float32, Float32], *, loc=None, ip=None ) -> Tuple[Float32, Float32]: diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index e3f6b1e7..12d5e422 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -31,7 +31,6 @@ from typing import ( Optional, ) from enum import Enum, auto -from typing_extensions import deprecated from cutlass.cutlass_dsl import ( const, @@ -1662,7 +1661,9 @@ class _Tensor(Tensor): @dsl_user_op -def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None): +def print_tensor( + tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None +): """Print content of the tensor in human readable format. Outputs the tensor data in a structured format showing both metadata @@ -1693,6 +1694,11 @@ def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None): [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) """ + if isinstance(tensor, TensorSSA): + tmp = make_fragment(tensor.shape, tensor.dtype) + tmp.store(tensor) + tensor = tmp + if not isinstance(tensor.type, _cute_ir.MemRefType): raise NotImplementedError( f"printing {tensor} is not supported because it doesn't support trivial dereferencing. " @@ -1769,7 +1775,7 @@ def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool: return False elif is_dynamic_expression(x): return _cute_ir.is_static(x.type) - elif isinstance(x, int) or x is None: + elif isinstance(x, (bool, int, float)) or x is None: return True elif isinstance(x, ScaledBasis): return x.is_static() @@ -2241,7 +2247,7 @@ def is_weakly_congruent( * X is a non-tuple value, OR * X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent. - Weak congruence allows scalar values to match with tuples, making it useful + Weak congruence allows scalar values to match with tuples, making it useful for determining whether an object has a hierarchical structure "up to" another. :param a: First object to compare @@ -2921,33 +2927,46 @@ def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple: return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a))) -def flatten(a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor]) -> tuple: +@overload +def flatten(a: Union[IntTuple, Coord, Shape, Stride]) -> IntTuple: ... +@overload +def flatten(a: Tensor) -> Tensor: ... +@overload +def flatten(a: Layout) -> Layout: ... + + +def flatten(a): """Flattens a CuTe data structure into a simpler form. For tuples, this function flattens the structure into a single-level tuple. - For non-tuple types, it returns the input unchanged. + For layouts, it returns a new layout with flattened shape and stride. + For tensors, it returns a new tensor with flattened layout. + For other types, it returns the input unchanged. :param a: The structure to flatten :type a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor] :return: The flattened structure :rtype: Union[tuple, Any] - :raises NotImplementedError: If input is a Layout or Tensor **Examples:** .. code-block:: python - flatten((1, 2, 3)) # Returns (1, 2, 3) - flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4) - flatten(5) # Returns 5 - """ - if isinstance(a, (Layout, Tensor)): - raise NotImplementedError("flatten layout and tensor is not supported") + flatten((1, 2, 3)) # Returns (1, 2, 3) + flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4) + flatten(5) # Returns 5 + flatten(Layout(shape, stride)) # Returns Layout(flatten(shape), flatten(stride)) + flatten(Tensor(layout)) # Returns Tensor(flatten(layout)) - if not isinstance(a, tuple): - return a - else: + """ + if isinstance(a, Tensor): + return make_tensor(a.iterator, flatten(a.layout)) + elif isinstance(a, Layout): + return make_layout(flatten(a.shape), stride=flatten(a.stride)) + elif isinstance(a, tuple): return flatten_to_tuple(a) + else: + return a def unflatten( @@ -4120,14 +4139,14 @@ def complement( @dsl_user_op def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout: if not isinstance(input, Layout): - raise TypeError(f"expects input of type Layout, but got {type(Layout)}") + raise TypeError(f"expects input of type Layout, but got {type(input)}") return _cute_ir.right_inverse(input=input, loc=loc, ip=ip) @dsl_user_op def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout: if not isinstance(input, Layout): - raise TypeError(f"expects input of type Layout, but got {type(Layout)}") + raise TypeError(f"expects input of type Layout, but got {type(input)}") return _cute_ir.left_inverse(input=input, loc=loc, ip=ip) @@ -5156,7 +5175,6 @@ def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): return TiledCopy(atom.op, trait) -@deprecated("Use make_tiled_copy_tv instead") def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): """Create a tiled type given a TV partitioner and tiler. @@ -5434,6 +5452,14 @@ def gemm( For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread election internally. Manual thread selection is not required in such cases. + Following dispatch rules are supported: + + - Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1) + - Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N) + - Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N) + - Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N) + - Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) + :param atom: MMA atom :type atom: MmaAtom :param d: Destination tensor @@ -5454,6 +5480,27 @@ def gemm( :rtype: None """ + a_rank = rank(a.shape) + b_rank = rank(b.shape) + c_rank = rank(c.shape) + d_rank = rank(d.shape) + + if a_rank != b_rank: + raise ValueError("`a` and `b` must have the same rank") + + if c_rank != d_rank: + raise ValueError("`c` and `d` must have the same rank") + + if a_rank == 1: + if c_rank > 2: + raise ValueError("`c` must have rank <= 2 when `a` has rank 1") + elif a_rank == 2: + if c_rank not in (2, 3): + raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2") + elif a_rank == 3: + if c_rank != 3: + raise ValueError("`c` must have rank 3 when `a` has rank 3") + value = atom._unpack(loc=loc, ip=ip, **kwargs) return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip) @@ -5645,6 +5692,76 @@ def copy( return _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip) +@dsl_user_op +def copy_atom_call( + atom: CopyAtom, + src: Tensor, + dst: Tensor, + *, + pred: Optional[Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + """ + Execute a single copy atom operation. + + The copy_atom_call operation executes a copy atom with the given operands. + Following src/dst layout of atom are valid: + * ((atom_v)) + * (atom_v) + + Note: The format ((atom_v, rest_v)) is NOT valid for copy_atom_call since it would + require multiple atom operations, which contradicts the definition of a single copy atom call. + + Examples: + + .. code-block:: python + + # Call a copy atom operation + cute.copy_atom_call(copy_atom, src_tensor, dst_tensor) + + An additional predication tensor can be provided. If the partitioned tensors have the following + logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile + consistent with ``(ATOM_REST,REST_M,...)``. + """ + if isinstance(src.type, _cute_ir.MemRefType) and isinstance( + dst.type, _cute_ir.MemRefType + ): + if src.element_type.width != dst.element_type.width: + raise TypeError( + "`copy_atom_call` currently only supports equal source and destination " + "element type bit width" + ) + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + if isinstance(pred, Tensor): + pred = pred.value + return _cute_ir.copy_atom_call( + value, src.value, dst.value, pred=pred, loc=loc, ip=ip + ) + + +def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None: + """ + The Prefetch algorithm. + + The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom. + Prefetch is used for loading tensors from global memory to L2. + + Prefetch accepts Copy Atom but not all are allowed. Currently, only support for tma load tensor prefetch. + + .. code-block:: python + + cute.prefetch(tma_atom, src) + + For Copy Atoms that require single-threaded execution, the copy op automatically handles thread + election internally. Manual thread selection is not required in such cases. + """ + dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip) + value = atom._unpack(loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr) + return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip) + #################################################################################################### # # TensorSSA class (experimental) @@ -5657,6 +5774,11 @@ class ReductionOp(Enum): MUL = auto() MAX = auto() MIN = auto() + INC = auto() + DEC = auto() + AND = auto() + OR = auto() + XOR = auto() def __str__(self): return self.name.lower() @@ -5697,6 +5819,7 @@ class TensorSSA(cutlass_arith.ArithValue): self._shape = shape self._dtype = dtype + self._layout = None @property def dtype(self) -> Type[Numeric]: @@ -5776,13 +5899,26 @@ class TensorSSA(cutlass_arith.ArithValue): ): res_type = Boolean - if lhs.shape != rhs.shape: - raise ValueError( - f"lhs and rhs must have the same shape type, but got {lhs.shape} and {rhs.shape}" - ) + assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}" - if not isinstance(rhs, TensorSSA): - raise TypeError(f"rhs must be TensorSSA but got {rhs}") + def _broadcast(s, t): + if s == 1: + return t + elif t == 1: + return s + elif s == t: + return s + else: + raise ValueError(f"cannot broadcast {s} and {t}") + + max_rank = max(rank(lhs.shape), rank(rhs.shape)) + lhs_shape = append(lhs.shape, 1, up_to_rank=max_rank) + rhs_shape = append(rhs.shape, 1, up_to_rank=max_rank) + res_shape = transform_leaf(_broadcast, lhs_shape, rhs_shape) + + # broadcast to the same shape + lhs = lhs.broadcast_to(res_shape) + rhs = rhs.broadcast_to(res_shape) if ( op in (operator.add, operator.sub) @@ -5807,6 +5943,38 @@ class TensorSSA(cutlass_arith.ArithValue): return res + def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA": + """ + Broadcast the tensor to the target shape. + """ + # pad source shape to the same rank + shape = append(self.shape, 1, up_to_rank=rank(target_shape)) + if shape == target_shape: + return self + + def _check_broadcast(s, t): + if s != t and s != 1: + raise ValueError( + f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}" + ) + + transform_leaf(_check_broadcast, shape, target_shape) + + # reshape to flatten N-D vector + flat_shp = flatten_to_tuple(shape) + temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type) + temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) + + # broadcast to result N-D vector + flat_tgt_shp = flatten_to_tuple(target_shape) + temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type) + temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip) + + res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore + res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip) + + return TensorSSA(res_1d_vect, target_shape, self.dtype) + def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA": """ Returns the results of tensor^other. @@ -6093,6 +6261,16 @@ class TensorSSA(cutlass_arith.ArithValue): """ return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip) + def __neg__(self, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the negation of the tensor. + + :return: The element-wise negation of the tensor + :rtype: TensorSSA + """ + + return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip) + def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None): # Coalesce and flatten source layout at terminal of coordinate # (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...) @@ -6158,17 +6336,13 @@ class TensorSSA(cutlass_arith.ArithValue): if crd is None: return self - if not has_underscore(crd) or depth(crd) == 0: - idx = crd2idx(crd, make_layout(self._shape)) - if is_static(idx): - res = vector.extract( - self, dynamic_position=[], static_position=[idx], loc=loc, ip=ip - ) - else: - res = vector.extract( - self, dynamic_position=[crd], static_position=[], loc=loc, ip=ip - ) - return self.dtype(res) + if not has_underscore(crd): + if self._layout is None: + self._layout = make_layout(self._shape, loc=loc, ip=ip) + idx = crd2idx(crd, self._layout, loc=loc, ip=ip) + idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip) + res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip) + return self.dtype(res_val) if not is_static(crd): raise ValueError("dynamic coordinate is not supported") @@ -6274,7 +6448,7 @@ class TensorSSA(cutlass_arith.ArithValue): :type op: operator :param init_val: The initial value for the reduction :type init_val: numeric - :param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with '_' are kept. + :param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept. :type reduction_profile: Coord :return: The reduced tensor @@ -6289,9 +6463,9 @@ class TensorSSA(cutlass_arith.ArithValue): reduce(f32 o (4, 5)) => f32 - reduce(f32 o (4, (5, 4)), reduction_profile=(_, 1)) + reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1)) => f32 o (4,) - reduce(f32 o (4, (5, 4)), reduction_profile=(_, (_, 1))) + reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1))) => f32 o (4, (5,)) """ # short-cut to no-op @@ -6354,21 +6528,6 @@ class TensorSSA(cutlass_arith.ArithValue): return self._build_result(res_vect, res_shp, loc=loc, ip=ip) -def _get_attr_for_type(ty, value): - if isinstance(ty, ir.IntegerType): - return ir.IntegerAttr.get(ty, value.to(int)) - elif isinstance(ty, ir.FloatType): - return ir.FloatAttr.get(ty, value.to(float)) - else: - raise TypeError(f"unsupported type: {ty}") - - -def _splat(res_ty, fill_value): - elem_attr = _get_attr_for_type(res_ty.element_type, fill_value) - vect_attr = ir.DenseElementsAttr.get_splat(res_ty, elem_attr) - return arith.constant(res_ty, vect_attr) - - @dsl_user_op def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA: """ @@ -6389,9 +6548,14 @@ def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> Tenso if isinstance(fill_value, (ir.Value, int, float, bool)): fill_value = dtype(fill_value) + elif isinstance(fill_value, Numeric): + fill_value = fill_value.to(dtype, loc=loc, ip=ip) + else: + raise ValueError(f"Expected fill_value be numeric type, but got {fill_value}") - res_mlir_type = T.vector(size, dtype.mlir_type) - return TensorSSA(_splat(res_mlir_type, fill_value), shape, dtype) + res_ty = T.vector(size, dtype.mlir_type) + res_val = vector.splat(res_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + return TensorSSA(res_val, shape, dtype) def full_like( @@ -6547,7 +6711,7 @@ class struct: **Usage:** - .. code-block:: + .. code-block:: python # Supports base_dsl scalar int/float elements, array and nested struct: @cute.struct @@ -6661,7 +6825,8 @@ class struct: Initializes a new memory range. :param dtype: The data type. - :param size: The size of the memory range in bytes. + :param size: Size of the memory range in bytes. A size of **0** is accepted, but in that + case the range can only be used for its address (e.g. as a partition marker). :param base: The base address of the memory range. """ self._dtype = dtype @@ -6673,9 +6838,9 @@ class struct: Returns start pointer to the data in this memory range. :return: A pointer to the start of the memory range. - :raises AssertionError: If the size of the memory range is not greater than zero. + :raises AssertionError: If the size of the memory range is negative. """ - assert self._size > 0 + assert self._size >= 0 return recast_ptr(self._base, dtype=self._dtype) def get_tensor(self, layout, swizzle=None, dtype=None): @@ -6716,31 +6881,48 @@ class struct: :param v: The object to align. Must be a struct, MemRange, or a scalar type. :param align: The alignment value to set. - :return: A copy of the object with the specified alignment. :raises TypeError: If the object is not a struct, MemRange, or a scalar type. + + :ivar _dtype: The data type to be aligned. + :ivar _align: The alignment of the data type. """ + _dtype = None + _align = None + def __new__(cls, name, bases, dct): return super().__new__(cls, name, bases, dct) def __getitem__(cls, params) -> Any: if len(params) == 2: - obj, align = params + dtype, align = params + assert align > 0 else: raise TypeError("Invalid struct.Align Arguments") - # make a copy of type and mark alignment - if struct._is_scalar_type(obj) or isinstance( - obj, (struct, struct._MemRangeMeta) + if not struct._is_scalar_type(dtype) and not isinstance( + dtype, (struct, struct._MemRangeMeta) ): - new_obj = py_copy.copy(obj) - setattr(new_obj, "_struct_alignment_", align) - return new_obj - else: raise TypeError( - "align only can be applied to sturct/MemRange/base_dsl scalar" + "align only can be applied to struct/MemRange/base_dsl scalar" ) + # Create new class with alignment + new_cls = type( + f"struct.Align[{dtype.__name__}, {align}]", + (struct.Align,), + {"_dtype": dtype, "_align": align}, + ) + return new_cls + + @property + def dtype(cls): + return cls._dtype + + @property + def align(cls): + return cls._align + class Align(metaclass=_AlignMeta): """ Aligns the given type by `Align[T, alignment]`. @@ -6768,6 +6950,7 @@ class struct: :raises TypeError: If the struct is empty. """ self._cls = cls + self.__name__ = f"struct::{cls.__name__}" # Get the class annotations self._annotations = cls.__annotations__ # Create a dictionary to store the offsets @@ -6780,12 +6963,10 @@ class struct: raise TypeError("Empty struct is not supported!") for name, object in self._annotations.items(): # get alignment of object - def alignof(object, default: int = 1): - return getattr(object, "_struct_alignment_", default) - - # alignment for the next offset - def align_offset(offset, align): - return (offset + (align - 1)) & ~(align - 1) + sub_align = 1 + if isinstance(object, struct._AlignMeta): + sub_align = object.align + object = object.dtype # switch addition order to support dynamic size def add_offset(val): @@ -6793,35 +6974,37 @@ class struct: # size of scalar if struct._is_scalar_type(object): - dtype_size = object.width // 8 - sub_align = alignof(object, dtype_size) - offset = align_offset(offset, sub_align) + dtype_size = max(1, object.width // 8) + sub_align = max(dtype_size, sub_align) + offset = self.align_offset(offset, sub_align) self._offsets[name] = offset offset = add_offset(dtype_size) # size of array is size_in_bytes, alignment is elem_size elif isinstance(object, struct._MemRangeMeta): - if object.size == 0: - continue # skip empty array - sub_align = alignof(object, max(1, object.elem_width // 8)) - offset = align_offset(offset, sub_align) + # Allow empty array as a free marker-only struct member. + # Use max(sub_align, ) because we might have in the future some + # object.elem_width less than 8, such as fp4, bit and others, + # and align_offset() does not support an alignment of 0. + sub_align = max(object.elem_width // 8, sub_align) + offset = self.align_offset(offset, sub_align) self._offsets[name] = offset offset = add_offset(object.size_in_bytes) # size of struct elif isinstance(object, struct): - sub_align = max(object.__alignof__(), alignof(object)) - offset = align_offset(offset, sub_align) + sub_align = max(object.__alignof__(), sub_align) + offset = self.align_offset(offset, sub_align) self._offsets[name] = offset offset = add_offset(object.__sizeof__()) else: raise TypeError( - f"Struct element only support sturct/array/base_dsl scalar, " + f"Struct element only support struct/array/base_dsl scalar, " f"but got {object}" ) # Total aligment determined by the strictest requirement alignment = max(alignment, sub_align) # Total size determined by alignment self._align_of = alignment - self._size_of = align_offset(offset, alignment) + self._size_of = self.align_offset(offset, alignment) # create the __init__ method for decorated struct def __call__(self, base: Any) -> None: @@ -6840,6 +7023,8 @@ class struct: setattr(cls, "_base", base) for name, off in self._offsets.items(): obj = self._annotations[name] + if isinstance(obj, struct._AlignMeta): + obj = obj.dtype if struct._is_scalar_type(obj): new_obj = recast_ptr(base + off, dtype=obj) setattr(cls, name, new_obj) @@ -6851,7 +7036,7 @@ class struct: setattr(cls, name, new_obj) else: raise TypeError( - f"Struct element only support sturct/array/base_dsl scalar, " + f"Struct element only support struct/array/base_dsl scalar, " f"but got {obj}" ) return cls @@ -6872,3 +7057,14 @@ class struct: # get alignment def __alignof__(self) -> int: return self._align_of + + # util func for aligning offset + @staticmethod + def align_offset(offset, align): + """ + Return the round-up offset up to the next multiple of align. + """ + assert align > 0 and not ( + align & (align - 1) + ), "align should be a strictly positive power of 2." + return (offset + (align - 1)) & ~(align - 1) diff --git a/python/CuTeDSL/cutlass/cute/math.py b/python/CuTeDSL/cutlass/cute/math.py index 3dda89c2..daaa6082 100644 --- a/python/CuTeDSL/cutlass/cute/math.py +++ b/python/CuTeDSL/cutlass/cute/math.py @@ -10,16 +10,53 @@ # is strictly prohibited. from .core import TensorSSA +from .typing import Numeric from cutlass._mlir.dialects import math, arith +from typing import Callable, Union -def acos(a: TensorSSA) -> TensorSSA: + +def _math_op(func: Callable, fastmath: bool, *args, **kwargs): + """Dispatch the function to either a TensorSSA or a Numeric(Float). + + :param func: The function to dispatch + :param args: The input tensor or scalar + :param kwargs: The input tensor or scalar + """ + arg_type = type(args[0]) + for arg in args: + if not isinstance(arg, TensorSSA) and ( + not isinstance(arg, Numeric) or not type(arg).is_float + ): + raise TypeError( + f"Expected a TensorSSA or Numeric(Float), but got {type(arg)}" + ) + if not isinstance(arg, arg_type): + raise TypeError( + f"Expected all inputs to be of type {arg_type}, but got {type(arg)}" + ) + + fastmath_flag = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none + if isinstance(args[0], TensorSSA): + return TensorSSA( + func(*args, fastmath=fastmath_flag), args[0].shape, args[0].dtype + ) + else: + args = [a.ir_value() for a in args] + return func(*args, fastmath=fastmath_flag) + + +def acos( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise arc cosine of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional :return: Tensor containing the arc cosine of each element in input tensor - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -29,16 +66,20 @@ def acos(a: TensorSSA) -> TensorSSA: y = x.load() # Load values z = acos(y) # Compute arc cosine """ - return TensorSSA(math.acos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.acos, fastmath, a) -def asin(a: TensorSSA) -> TensorSSA: +def asin( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise arc sine of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional :return: Tensor containing the arc sine of each element in input tensor - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -48,18 +89,20 @@ def asin(a: TensorSSA) -> TensorSSA: y = x.load() # Load values z = asin(y) # Compute arc sine """ - return TensorSSA(math.asin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.asin, fastmath, a) -def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def atan( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise arc tangent of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the arc tangent of each element in input tensor - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -70,23 +113,25 @@ def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA: z = atan(y) # Compute arc tangent """ raise NotImplementedError("atan is not implemented") - return TensorSSA(math.atan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.atan, fastmath, a) -def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA: +def atan2( + a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise arc tangent of two tensors. Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians between the positive x-axis and the point given by the coordinates (b, a). :param a: First input tensor (y-coordinates) - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param b: Second input tensor (x-coordinates) - :type b: TensorSSA + :type b: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the arc tangent of a/b element-wise - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -96,20 +141,20 @@ def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA: x = cute.make_fragment(ptr2, layout).load() # x coordinates theta = atan2(y, x) # Compute angles """ - return TensorSSA( - math.atan2(a, b, fastmath=arith.FastMathFlags.none), a.shape, a.dtype - ) + return _math_op(math.atan2, fastmath, a, b) -def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def cos( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise cosine of the input tensor. :param a: Input tensor (in radians) - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the cosine of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -119,21 +164,23 @@ def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = cos(y) # Compute cosine """ - return TensorSSA(math.cos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.cos, fastmath, a) -def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def erf( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise error function of the input tensor. The error function is defined as: erf(x) = 2/√π ∫[0 to x] exp(-t²) dt :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the error function value for each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -143,18 +190,43 @@ def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = erf(y) # Compute error function """ - return TensorSSA(math.erf(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.erf, fastmath, a) -def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def exp( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise exponential of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the exponential of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = exp(y) # Compute exponential + """ + return _math_op(math.exp, fastmath, a) + + +def exp2( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise base-2 exponential of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing 2 raised to the power of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -164,18 +236,20 @@ def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = exp2(y) # Compute 2^x """ - return TensorSSA(math.exp2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.exp2, fastmath, a) -def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def log( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise natural logarithm of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the natural logarithm of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -185,18 +259,20 @@ def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = log(y) # Compute natural logarithm """ - return TensorSSA(math.log(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.log, fastmath, a) -def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def log2( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise base-2 logarithm of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the base-2 logarithm of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -206,18 +282,20 @@ def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = log2(y) # Compute log base 2 """ - return TensorSSA(math.log2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.log2, fastmath, a) -def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def log10( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise base-10 logarithm of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the base-10 logarithm of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -227,20 +305,22 @@ def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = log10(y) # Compute log base 10 """ - return TensorSSA(math.log10(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.log10, fastmath, a) -def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def rsqrt( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise reciprocal square root of the input tensor. Computes 1/√x element-wise. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the reciprocal square root of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -250,18 +330,20 @@ def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = rsqrt(y) # Compute 1/√x """ - return TensorSSA(math.rsqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.rsqrt, fastmath, a) -def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def sin( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise sine of the input tensor. :param a: Input tensor (in radians) - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the sine of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -271,18 +353,20 @@ def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = sin(y) # Compute sine """ - return TensorSSA(math.sin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.sin, fastmath, a) -def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def sqrt( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise square root of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the square root of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -292,16 +376,20 @@ def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = sqrt(y) # Compute square root """ - return TensorSSA(math.sqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.sqrt, fastmath, a) -def tan(a: TensorSSA) -> TensorSSA: +def tan( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise tangent of the input tensor. :param a: Input tensor (in radians) - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional :return: Tensor containing the tangent of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -311,18 +399,20 @@ def tan(a: TensorSSA) -> TensorSSA: y = x.load() # Load values z = tan(y) # Compute tangent """ - return TensorSSA(math.tan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.tan, fastmath, a) -def tanh(a: TensorSSA, fastmath: bool = False) -> TensorSSA: +def tanh( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: """Compute element-wise hyperbolic tangent of the input tensor. :param a: Input tensor - :type a: TensorSSA + :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional :return: Tensor containing the hyperbolic tangent of each element - :rtype: TensorSSA + :rtype: Union[TensorSSA, Numeric] Example: @@ -332,7 +422,7 @@ def tanh(a: TensorSSA, fastmath: bool = False) -> TensorSSA: y = x.load() # Load values z = tanh(y) # Compute hyperbolic tangent """ - return TensorSSA(math.tanh(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype) + return _math_op(math.tanh, fastmath, a) __all__ = [ @@ -342,6 +432,7 @@ __all__ = [ "atan2", "cos", "erf", + "exp", "exp2", "log", "log10", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/python/CuTeDSL/cutlass/cute/nvgpu/common.py index 87a01be9..1b0c4c82 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/common.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/common.py @@ -8,7 +8,7 @@ # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. - +import enum from dataclasses import dataclass from typing import Type, Optional @@ -101,6 +101,42 @@ class MmaUniversalTrait(core.Trait): #################################################################################################### +class MemoryOrder(enum.Enum): + WEAK = _cute_ir.MemOrderKind.WEAK + RELAXED = _cute_ir.MemOrderKind.RELAXED + ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE + RELEASE = _cute_ir.MemOrderKind.RELEASE + ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL + SC = _cute_ir.MemOrderKind.SC + MMIO = _cute_ir.MemOrderKind.MMIO + CONSTANT = _cute_ir.MemOrderKind.CONSTANT + VOLATILE = _cute_ir.MemOrderKind.VOLATILE + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MemOrderKind: + return self.value + + +class MemoryScope(enum.Enum): + CTA = _cute_ir.MemScopeKind.CTA + CLUSTER = _cute_ir.MemScopeKind.CLUSTER + GPU = _cute_ir.MemScopeKind.GPU + SYS = _cute_ir.MemScopeKind.SYS + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MemScopeKind: + return self.value + @dataclass(frozen=True) class CopyUniversalOp(core.CopyOp): """ @@ -133,13 +169,18 @@ class CopyUniversalOp(core.CopyOp): **kwargs, ) -> "CopyUniversalTrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) + memory_order = kwargs.get("memory_order", MemoryOrder.WEAK) + memory_scope = kwargs.get("memory_scope", MemoryScope.CTA) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): raise ValueError( "expects a 'num_bits_per_copy' kw argument of type int that is non-negative " f"when creating a copy Atom for {self.__class__.__name__}" ) ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - copy_internal_type.mlir_type, num_bits_per_copy + copy_internal_type.mlir_type, + num_bits_per_copy, + memory_order._to_ir(), + memory_scope._to_ir(), ) return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py index b4b88031..246360c2 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py @@ -23,6 +23,7 @@ __all__ = [ "CopyBulkTensorTileG2SOp", "CopyBulkTensorTileG2SMulticastOp", "CopyBulkTensorTileS2GOp", + "CopyReduceBulkTensorTileS2GOp", # # helpers.py # diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py index 8744a376..a1549560 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py @@ -19,7 +19,7 @@ import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir -from ...core import CopyOp, Trait +from ...core import CopyOp, Trait, ReductionOp from ...typing import Int16, Pointer, Integer, Numeric from ..common import OpError from ..tcgen05.mma import CtaGroup @@ -80,6 +80,12 @@ class CopyG2SOp(CopyOp): **kwargs, ) -> "CopyG2STrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", None) + # Verify that the user provided enum values + if not isinstance(self.cache_mode, LoadCacheMode): + raise OpError( + self, + "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", + ) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0): raise ValueError( "expects a 'num_bits_per_copy' kw argument of type int that is positive " @@ -330,7 +336,7 @@ class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): @dataclass(frozen=True) class CopyBulkTensorTileS2GOp(CopyOp): """ - Bulk tensor asynchrnous SMEM to GMEM Copy Operation using the TMA unit. + Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit. See the `PTX documentation `__. This Operation uses TMA in the ``.tile`` mode. @@ -379,3 +385,87 @@ class CopyBulkTensorTileS2GTrait(Trait): exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip ) return exec_value + +@dataclass(frozen=True) +class CopyReduceBulkTensorTileS2GOp(CopyOp): + """ + Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode. + """ + + reduction_kind: ReductionOp = ReductionOp.ADD + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post__init__(self): + # Arch verification + arch = CuTeDSL.__get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + return "cp.async SMEM -> GMEM bulk tensor reduction Operation" + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyReduceBulkTensorTileS2GTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind: + if self.reduction_kind == ReductionOp.ADD: + return _cute_nvgpu_ir.ReductionKind.ADD + elif self.reduction_kind == ReductionOp.MIN: + return _cute_nvgpu_ir.ReductionKind.MIN + elif self.reduction_kind == ReductionOp.MAX: + return _cute_nvgpu_ir.ReductionKind.MAX + elif self.reduction_kind == ReductionOp.INC: + return _cute_nvgpu_ir.ReductionKind.INC + elif self.reduction_kind == ReductionOp.DEC: + return _cute_nvgpu_ir.ReductionKind.DEC + elif self.reduction_kind == ReductionOp.AND: + return _cute_nvgpu_ir.ReductionKind.AND + elif self.reduction_kind == ReductionOp.OR: + return _cute_nvgpu_ir.ReductionKind.OR + elif self.reduction_kind == ReductionOp.XOR: + return _cute_nvgpu_ir.ReductionKind.XOR + else: + assert False, "unrecognized self.reduction_kind" + + +class CopyReduceBulkTensorTileS2GTrait(Trait): + def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None): + """ + Custom implementation of unpack for non-executable TMAs. + """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + if isinstance(tma_desc_ptr, Pointer): + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + +__all__ = [ + "LoadCacheMode", + "CopyG2SOp", + "CopyBulkTensorTileG2SOp", + "CopyBulkTensorTileG2SMulticastOp", + "CopyBulkTensorTileS2GOp", + "CopyReduceBulkTensorTileS2GOp", +] diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py index f8374407..f64f07f1 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -22,9 +22,11 @@ from .copy import ( CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, + CopyReduceBulkTensorTileS2GOp, CopyBulkTensorTileG2SNonExecTrait, CopyBulkTensorTileG2SMulticastNonExecTrait, CopyBulkTensorTileS2GTrait, + CopyReduceBulkTensorTileS2GTrait, ) @@ -34,6 +36,7 @@ def make_tiled_tma_atom( CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, + CopyReduceBulkTensorTileS2GOp, ], gmem_tensor: Tensor, smem_layout: Union[Layout, core.ComposedLayout], @@ -67,7 +70,7 @@ def make_tiled_tma_atom( similarly to any other CuTe tensors using the algebra. :param op: The Copy Operation to construct an Atom for - :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp] + :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp] :param gmem_tensor: The GMEM tensor involved in the Copy :type gmem_tensor: Tensor :param smem_layout: The SMEM layout to construct the Copy Atom for @@ -141,6 +144,17 @@ def make_tiled_tma_atom( ip=ip, ) return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1] + elif isinstance(op, CopyReduceBulkTensorTileS2GOp): + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce( + gmem_tensor.value, + smem_layout, + cta_v_map, + op._to_ir(), + internal_type=internal_type, + loc=loc, + ip=ip, + ) + return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1] else: raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}") diff --git a/python/CuTeDSL/cutlass/cute/runtime.py b/python/CuTeDSL/cutlass/cute/runtime.py index c86890d3..9128c67a 100644 --- a/python/CuTeDSL/cutlass/cute/runtime.py +++ b/python/CuTeDSL/cutlass/cute/runtime.py @@ -21,7 +21,7 @@ from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir from cutlass.base_dsl.dsl import is_dynamic_expression -from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry +from cutlass.cutlass_dsl import JitArgAdapterRegistry # Local modules imports from .typing import ( @@ -82,42 +82,36 @@ class _Pointer(Pointer): self._dtype = dtype self._addr_space = mem_space - is_in_device = mem_space == _cute_ir.AddressSpace.gmem if assumed_align is None: - if is_in_device: - self._assumed_align = 32 - else: - self._assumed_align = dtype.width // 8 + self._assumed_align = dtype.width // 8 else: self._assumed_align = assumed_align - class PtrDescriptor(ctypes.Structure): - """A ctype descriptor for CuTe memref ptr""" - - _fields_ = [("ptr", ctypes.c_void_p)] - - def __str__(self): - return f"0x{self.ptr:016x}" - - self._desc = PtrDescriptor(int(self._pointer)) - self._c_pointer = ctypes.cast(ctypes.pointer(self._desc), ctypes.c_void_p) + self._c_pointer = None assert ( - self._desc.ptr % self._assumed_align == 0 + int(self._pointer) % self._assumed_align == 0 ), f"pointer must be {self._assumed_align} bytes aligned" def size_in_bytes(self) -> int: + self._desc = ctypes.c_void_p(int(self._pointer)) return ctypes.sizeof(self._desc) def __get_mlir_types__(self): return [self.mlir_type] def __c_pointers__(self): + if self._c_pointer is None: + self._desc = ctypes.c_void_p(int(self._pointer)) + self._c_pointer = ctypes.addressof(self._desc) return [self._c_pointer] def __new_from_mlir_values__(self, values): assert len(values) == 1 return values[0] + def __extract_mlir_values__(self): + return [self._c_pointer] + # Move mlir Type out of __init__ to decouple with mlir Context @property def mlir_type(self) -> ir.Type: @@ -145,7 +139,7 @@ class _Pointer(Pointer): return False def __str__(self) -> str: - return f"Ptr<0x{self._desc.ptr:016x}@{self._addr_space}>" + return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>" def __repr__(self): return self.__str__() diff --git a/python/CuTeDSL/cutlass/pipeline/__init__.py b/python/CuTeDSL/cutlass/pipeline/__init__.py index d2729787..7df24dd6 100644 --- a/python/CuTeDSL/cutlass/pipeline/__init__.py +++ b/python/CuTeDSL/cutlass/pipeline/__init__.py @@ -31,9 +31,12 @@ from .helpers import ( from .sm90 import ( PipelineAsync, + PipelineCpAsync, PipelineTmaAsync, PipelineTmaMultiConsumersAsync, PipelineTmaStore, + PipelineProducer, + PipelineConsumer, ) from .sm100 import ( @@ -53,10 +56,13 @@ __all__ = [ "PipelineUserType", "PipelineState", "PipelineAsync", + "PipelineCpAsync", "PipelineTmaAsync", "PipelineTmaUmma", "PipelineTmaMultiConsumersAsync", "PipelineAsyncUmma", "PipelineUmmaAsync", "PipelineTmaStore", + "PipelineProducer", + "PipelineConsumer", ] diff --git a/python/CuTeDSL/cutlass/pipeline/helpers.py b/python/CuTeDSL/cutlass/pipeline/helpers.py index 68acfdab..b5b94899 100644 --- a/python/CuTeDSL/cutlass/pipeline/helpers.py +++ b/python/CuTeDSL/cutlass/pipeline/helpers.py @@ -89,6 +89,8 @@ class PipelineOp(enum.Enum): TmaStore = enum.auto() # Composite of multiple PipelineOps Composite = enum.auto() + # Async load without TMA + AsyncLoad = enum.auto() def _get_pipeline_op(type_str): @@ -226,6 +228,8 @@ class MbarrierArray(SyncObject): self.arrive_tcgen05mma(index, dst, cta_group) elif self.op_type in [PipelineOp.TmaLoad]: self.arrive_and_expect_tx(index, self.tx_count) + elif self.op_type is PipelineOp.AsyncLoad: + self.arrive_cp_async_mbarrier(index) else: assert ( False @@ -237,6 +241,9 @@ class MbarrierArray(SyncObject): else: cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank) + def arrive_cp_async_mbarrier(self, index: int): + cute.arch.cp_async_mbarrier_arrive_noinc(self.get_barrier(index)) + def arrive_tcgen05mma( self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup ) -> None: diff --git a/python/CuTeDSL/cutlass/pipeline/sm100.py b/python/CuTeDSL/cutlass/pipeline/sm100.py index 591e1d7a..2feed8cc 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm100.py +++ b/python/CuTeDSL/cutlass/pipeline/sm100.py @@ -19,6 +19,7 @@ import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, if_generate from cutlass.pipeline import ( + Agent, CooperativeGroup, PipelineOp, PipelineState, @@ -106,9 +107,9 @@ class PipelineTmaUmma(PipelineAsync): :type barrier_storage: cute.Pointer :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent + :param consumer_group: `CooperativeGroup` for the consumer agent :type consumer_group: CooperativeGroup :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage :type tx_count: int @@ -258,9 +259,9 @@ class PipelineAsyncUmma(PipelineAsync): :type barrier_storage: cute.Pointer :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent + :param consumer_group: `CooperativeGroup` for the consumer agent :type consumer_group: CooperativeGroup :param cta_layout_vmnk: Layout of the cluster shape :type cta_layout_vmnk: cute.Layout | None @@ -368,9 +369,9 @@ class PipelineUmmaAsync(PipelineAsync): :type barrier_storage: cute.Pointer :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent + :param consumer_group: `CooperativeGroup` for the consumer agent :type consumer_group: CooperativeGroup :param cta_layout_vmnk: Layout of the cluster shape :type cta_layout_vmnk: cute.Layout | None diff --git a/python/CuTeDSL/cutlass/pipeline/sm90.py b/python/CuTeDSL/cutlass/pipeline/sm90.py index 71b99519..5fc19960 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm90.py +++ b/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -10,15 +10,18 @@ # is strictly prohibited. import enum +from typing import Type, Tuple from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional, Union import warnings +import cutlass import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, Int32, if_generate from cutlass.pipeline import ( + Agent, CooperativeGroup, PipelineOp, SyncObject, @@ -91,6 +94,30 @@ class PipelineAsync: - D: Data ready (producer has written data to buffer) - R: Consumer reading (consumer is consuming data from buffer) + **Example:** + + .. code-block:: python + + # Create pipeline with 5 stages + pipeline = PipelineAsync.create( + num_stages=5, # number of pipeline stages + producer_group=producer_warp, + consumer_group=consumer_warp + barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory + ) + + producer, consumer = pipeline.make_participants() + # Producer side + for i in range(num_iterations): + handle = producer.acquire_and_advance() # Wait for buffer to be empty & Move index to next stage + # Write data to pipeline buffer + handle.commit() # Signal buffer is full + + # Consumer side + for i in range(num_iterations): + handle = consumer.wait_and_advance() # Wait for buffer to be full & Move index to next stage + # Read data from pipeline buffer + handle.release() # Signal buffer is empty """ sync_object_full: SyncObject @@ -114,6 +141,7 @@ class PipelineAsync: PipelineOp.TmaLoad, PipelineOp.TCGen05Mma, PipelineOp.Composite, + PipelineOp.AsyncLoad, ]: return MbarrierArray( barrier_storage=barrier_storage, @@ -232,6 +260,74 @@ class PipelineAsync: state.advance() self.producer_acquire(state) + # Util methods to manage produer and consumer + def make_producer(self): + state = make_pipeline_state(PipelineUserType.Producer, self.num_stages) + return PipelineProducer(self, state, self.sync_object_full.cg) + + def make_consumer(self): + state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages) + return PipelineConsumer(self, state, self.sync_object_empty.cg) + + def make_participants(self): + return self.make_producer(), self.make_consumer() + + + +@dataclass(frozen=True) +class PipelineCpAsync(PipelineAsync): + """ + PipelineCpAsync is used for CpAsync producers and AsyncThread consumers (e.g. Hopper non-TMA mainloops). + """ + + @staticmethod + def create( + barrier_storage: cute.Pointer, + num_stages: Int32, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + producer_mask: Int32 = None, + consumer_mask: Int32 = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param producer_mask: Mask for signaling arrives for the producer agent + :type producer_mask: Int32 | None + :param consumer_mask: Mask for signaling arrives for the consumer agent + :type consumer_mask: Int32 | None + """ + producer_type = PipelineOp.AsyncLoad + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_array_full = PipelineCpAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_array_empty = PipelineCpAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + pipeline_init_wait() + + return PipelineCpAsync( + sync_object_array_full, + sync_object_array_empty, + num_stages, + producer_mask, + consumer_mask, + ) + + @dataclass(frozen=True) class PipelineTmaAsync(PipelineAsync): """ @@ -294,9 +390,9 @@ class PipelineTmaAsync(PipelineAsync): :type barrier_storage: cute.Pointer :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent + :param consumer_group: `CooperativeGroup` for the consumer agent :type consumer_group: CooperativeGroup :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage :type tx_count: int @@ -404,11 +500,11 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync): :type barrier_storage: cute.Pointer :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup - :param consumer_group_umma: CooperativeGroup for the UMMA consumer agent + :param consumer_group_umma: `CooperativeGroup` for the UMMA consumer agent :type consumer_group_umma: CooperativeGroup - :param consumer_group_async: CooperativeGroup for the AsyncThread consumer agent + :param consumer_group_async: `CooperativeGroup` for the AsyncThread consumer agent :type consumer_group_async: CooperativeGroup :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage :type tx_count: int @@ -529,9 +625,10 @@ class PipelineTmaStore(PipelineAsync): This helper function computes any necessary attributes and returns an instance of PipelineTmaStore. :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup """ + producer_type = PipelineOp.TmaStore producer = (producer_type, producer_group) @@ -556,3 +653,333 @@ class PipelineTmaStore(PipelineAsync): self.sync_object_full.tail() +################################################################# +# Utilities to help user of pipeline to simplify the workflow +################################################################# + + +class ImmutableResourceHandle: + __origin: PipelineAsync + __immutable_state: PipelineState + + def __init__(self, origin: PipelineAsync, immutable_state: PipelineState): + self.__origin = origin + self.__immutable_state = immutable_state + + @property + def index(self): + """Get the index of the current pipeline stage.""" + return self.__immutable_state.index + + @property + def count(self): + """Get the count of how many handles this producer has committed. + This is useful for tracking the number of blocks that have been loaded from gmem. + """ + return self.__immutable_state.count + + def get_origin(self): + """Get the original pipeline this resource handle belongs to.""" + return self.__origin + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + # TODO: need to handle pipeline as well + return self.__immutable_state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Producer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Producer instance with state initialized from values + :rtype: Producer + """ + return self.__class__( + self.__origin, self.__immutable_state.__new_from_mlir_values__(values) + ) + +class PipelineProducer: + """A class representing a producer in an asynchronous pipeline. + + The Producer class manages the producer side of an asynchronous pipeline, handling + synchronization and state management for producing data. It provides methods for + acquiring, committing, and advancing through pipeline stages. + + :ivar __pipeline: The asynchronous pipeline this producer belongs to + :type __pipeline: PipelineAsync + :ivar __state: The current state of the producer in the pipeline + :type __state: PipelineState + :ivar __group: The cooperative group this producer operates in + :type __group: CooperativeGroup + + **Examples:** + + .. code-block:: python + + pipeline = PipelineAsync.create(...) + producer = pipeline.create_producer(producer_group, stages) + for i in range(iterations): + handle = producer.acquire_and_advance() # Wait for buffer to be empty + # Produce data + producer.commit(handle) # Signal data is ready + # An alternative way to do this is: + # handle.commit() # Signal data is ready + """ + + __pipeline: PipelineAsync + __state: PipelineState + __group: CooperativeGroup + + class ImmutableResourceHandle(ImmutableResourceHandle): + @property + def barrier(self): + """Get the barrier pointer for the current pipeline stage. + + :return: Pointer to the barrier for the current stage + :rtype: cute.Pointer + """ + return self.get_origin().producer_get_barrier( + self._ImmutableResourceHandle__immutable_state + ) + + def commit(self): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + self.get_origin().producer_commit( + self._ImmutableResourceHandle__immutable_state + ) + + def __init__(self, pipeline, state, group: CooperativeGroup): + """Initialize a new Producer instance. + + :param pipeline: The pipeline this producer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self.__pipeline = pipeline + self.__state = state + self.__group = group + + def acquire( + self, + try_acquire_token: Optional[Boolean] = None, + ) -> ImmutableResourceHandle: + """Wait for the current buffer to be empty before producing data. + This is a blocking operation. + + :param try_acquire_token: Optional token to try to acquire the buffer + :type try_acquire_token: Optional[Boolean] + :return: A handle to the producer for committing the data + :rtype: ImmutableResourceHandle + """ + self.__pipeline.producer_acquire(self.__state, try_acquire_token) + handle = PipelineProducer.ImmutableResourceHandle( + self.__pipeline, self.__state.clone() + ) + return handle + + def advance(self): + """Move to the next pipeline stage.""" + self.__state.advance() + + def acquire_and_advance( + self, try_acquire_token: Optional[Boolean] = None + ) -> ImmutableResourceHandle: + """Wait for the current buffer to be empty before producing data. + Then advance to the next stage. + This is a blocking operation. + + :param try_acquire_token: Optional token to try to acquire the buffer + :type try_acquire_token: Optional[Boolean] + :return: A handle to the producer for committing the data + :rtype: ImmutableResourceHandle + """ + handle = self.acquire(try_acquire_token) + self.advance() + return handle + + def try_acquire(self) -> Boolean: + """Try to acquire the current buffer without blocking. + + :return: True if acquisition was successful, False otherwise + :rtype: Boolean + """ + return self.__pipeline.producer_try_acquire(self.__state) + + def commit(self, handle: Optional[ImmutableResourceHandle] = None): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + if handle is not None: + assert ( + handle.get_origin() is self + ), "ResourceHandle does not belong to this PipelineProducer instance" + handle.commit() + else: + self.__pipeline.producer_commit(self.__state) + + def tail(self): + """Ensure all used buffers are properly synchronized before producer exit. + This should be called before the producer finishes to avoid dangling signals. + """ + self.__pipeline.producer_tail(self.__state) + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + # TODO: need to handle pipeline as well + return self.__state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Producer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Producer instance with state initialized from values + :rtype: Producer + """ + return PipelineProducer( + self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group + ) + +class PipelineConsumer: + """A class representing a consumer in an asynchronous pipeline. + + The Consumer class manages the consumer side of an asynchronous pipeline, handling + synchronization and state management for consuming data. It provides methods for + waiting, releasing, and advancing through pipeline stages. + + :ivar __pipeline: The asynchronous pipeline this consumer belongs to + :type __pipeline: PipelineAsync + :ivar __state: The current state of the consumer in the pipeline + :type __state: PipelineState + :ivar __group: The cooperative group this consumer operates in + :type __group: CooperativeGroup + + **Examples:** + .. code-block:: python + + pipeline = PipelineAsync.create(...) + consumer = pipeline.create_consumer(consumer_group, stages) + for i in range(iterations): + handle = consumer.wait_and_advance() # Wait for data to be ready + # Consume data + consumer.release(handle) # Signal buffer is empty + # An alternative way to do this is: + # handle.release() # Signal buffer is empty + """ + + __pipeline: PipelineAsync + __state: PipelineState + __group: CooperativeGroup + + class ImmutableResourceHandle(ImmutableResourceHandle): + def release(self): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + self.get_origin().consumer_release( + self._ImmutableResourceHandle__immutable_state + ) + + def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup): + """Initialize a new Consumer instance. + + :param pipeline: The pipeline this consumer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self.__pipeline = pipeline + self.__group = group + self.__state = state + + def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle: + """Wait for data to be ready in the current buffer. + This is a blocking operation. + + :param try_wait_token: Optional token to try to wait for the buffer + :type try_wait_token: Optional[Boolean] + :return: A handle to the consumer for releasing the data + :rtype: PipelineConsumerHandle + """ + self.__pipeline.consumer_wait(self.__state, try_wait_token) + handle = PipelineConsumer.ImmutableResourceHandle( + self.__pipeline, self.__state.clone() + ) + return handle + + def advance(self): + """Move to the next pipeline stage.""" + self.__state.advance() + + def wait_and_advance( + self, try_wait_token: Optional[Boolean] = None + ) -> ImmutableResourceHandle: + """Wait for data to be ready in the current buffer. + Then advance to the next stage. + This is a blocking operation. + + :param try_wait_token: Optional token to try to wait for the buffer + :type try_wait_token: Optional[Boolean] + :return: A handle to the consumer for releasing the data + :rtype: PipelineConsumerHandle + """ + handle = self.wait(try_wait_token) + self.advance() + return handle + + def try_wait(self) -> Boolean: + """Try to check if data is ready without blocking. + + :return: True if data is ready, False otherwise + :rtype: Boolean + """ + return self.__pipeline.consumer_try_wait(self.__state) + + def release(self, handle: Optional[ImmutableResourceHandle] = None): + """Signal that data consumption is complete for the current stage. + This allows producers to start producing new data. + """ + if handle is not None: + assert ( + handle.get_origin() is self + ), "ResourceHandle does not belong to this PipelineConsumer instance" + handle.release() + else: + self.__pipeline.consumer_release(self.__state) + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + return self.__state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Consumer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Consumer instance with state initialized from values + :rtype: Consumer + """ + # TODO: need to call pipeline.__new_from_mlir_values__ recursively + return PipelineConsumer( + self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group + ) diff --git a/python/CuTeDSL/cutlass/torch.py b/python/CuTeDSL/cutlass/torch.py index 066c2816..e5ee5777 100644 --- a/python/CuTeDSL/cutlass/torch.py +++ b/python/CuTeDSL/cutlass/torch.py @@ -9,6 +9,8 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +import ctypes +from math import prod from dataclasses import dataclass from enum import Enum from typing import Optional, Type, Union @@ -54,6 +56,25 @@ def dtype(ty: Type[Numeric]): return torch_dtype +def as_tensor(pointer, shape, torch_type): + """Convert a pointer to a torch tensor""" + if torch_type.itemsize == 1: + cytype = ctypes.c_uint8 + elif torch_type.itemsize == 2: + cytype = ctypes.c_uint16 + elif torch_type.itemsize == 4: + cytype = ctypes.c_uint32 + elif torch_type.itemsize == 8: + cytype = ctypes.c_uint64 + else: + raise ValueError(f"Unsupported torch dtype: {torch_type}") + cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype)) + arr = (cpointer._type_ * prod(shape)).from_address( + ctypes.addressof(cpointer.contents) + ) + return torch.frombuffer(arr, dtype=torch_type).view(*shape) + + @dataclass class ScalarInitConfig: """Configuration for scalar initialization""" @@ -128,7 +149,7 @@ def create_and_permute_torch_tensor( if not isinstance(init_config, GaussianInitConfig): raise ValueError("init_config must be GaussianInitConfig()") f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std) - f32_torch_tensor = f32_torch_tensor * (1 << init_config.scale) + f32_torch_tensor = f32_torch_tensor * init_config.scale else: raise ValueError(f"Invalid init type: {init_type}") diff --git a/python/CuTeDSL/cutlass/utils/__init__.py b/python/CuTeDSL/cutlass/utils/__init__.py index 39add25e..aec0a186 100644 --- a/python/CuTeDSL/cutlass/utils/__init__.py +++ b/python/CuTeDSL/cutlass/utils/__init__.py @@ -64,6 +64,18 @@ from .smem_capacity import ( get_smem_capacity_in_bytes, ) +from .distributed_helpers import ( + spin_lock_wait, + spin_lock_multimem_arrive, + multimem_ld_reduce_8xf16, + multimem_ld_reduce_4xf32, + multimem_ld_reduce_8xbf16, + multimem_ld_reduce_16xe4m3, + multimem_ld_reduce_16xe5m2, + multimem_st_4xb32, + sm_wise_inter_gpu_multimem_barrier, +) + __all__ = [ "get_smem_capacity_in_bytes", "SmemAllocator", diff --git a/python/CuTeDSL/cutlass/utils/distributed_helpers.py b/python/CuTeDSL/cutlass/utils/distributed_helpers.py new file mode 100644 index 00000000..5853c56c --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/distributed_helpers.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from functools import partial +from typing import Tuple + +import cutlass.cute as cute +from cutlass.cutlass_dsl import T, dsl_user_op, while_generate + +from cutlass._mlir import ir +from cutlass._mlir.dialects import arith, llvm, nvvm, scf +from cutlass._mlir.dialects.nvvm import ( + MemOrderKind, + MemScopeKind, + AtomicOpKind, +) +from cutlass.cute.typing import Pointer, Int32, Boolean + + +@dsl_user_op +def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32: + return nvvm.atomicrmw( + T.i32(), + AtomicOpKind.ADD, + dst_ptr.llvm_ptr, + val.ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.SYS, + loc=loc, + ip=ip, + ) + + +@cute.jit +def ld_bypass(input_tensor: cute.Tensor): + fragment = cute.make_fragment(input_tensor.layout, input_tensor.element_type) + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + input_tensor.element_type, + memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, + memory_scope=cute.nvgpu.common.MemoryScope.SYS, + ) + cute.copy(copy_atom_load, input_tensor, fragment) + vals = fragment.load() + return vals + +@cute.jit +def spin_lock_wait(lock_ptr: Pointer, expect_count: Int32, mem_order : str = "relaxed", mem_scope : str = "gpu", loc=None, ip=None) -> None: + """ + wait on a spin lock until the expected count is reached. + """ + res = 0 + while res != expect_count: + res = nvvm.atomicrmw( + T.i32(), + AtomicOpKind.CAS, + lock_ptr.llvm_ptr, + Int32(0).ir_value(loc=loc, ip=ip), + b=Int32(expect_count).ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED, + syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS + ) + + +@dsl_user_op +def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None: + """ + add 1 to the multimem address + """ + llvm.inline_asm( + None, + [mc_ptr.toint().ir_value()], + "multimem.red.release.sys.global.add.u32 [$0], 1;", + "l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + +@dsl_user_op +def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None: + """ + add 1 to the multimem address + """ + llvm.inline_asm( + None, + [mc_ptr.toint().ir_value()], + "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;", + "l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + + +def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: + """ + arrive a spin lock when the lock_ptr is a multimem address. + """ + multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip) + + +def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, num_ranks, loc=None, ip=None) -> None : + """ + barrier for inter-gpu sm-wise + """ + bidx, bidy, bidz = cute.arch.block_idx() + bdimx, bdimy, _ = cute.arch.grid_dim() + pid = bidx + bidy * bdimx + bidz * bdimx * bdimy + multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip) + cute.arch.fence_proxy(cute.arch.ProxyKind.alias) + spin_lock_wait(barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip) + + +@dsl_user_op +def multimem_ld_reduce_base( + mc_ptr: Pointer, + *, + ptx_string: str = "", + loc=None, + ip=None, +) -> Tuple[Int32, Int32, Int32, Int32]: + # ld reduce 8xf16 elts + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + return_struct = llvm.inline_asm( + ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"), + [mc_ptr_int], + ptx_string, + "=r,=r,=r,=r,l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)] + return return_regs[0], return_regs[1], return_regs[2], return_regs[3] + + +multimem_ld_reduce_8xf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_4xf32 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_8xbf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_16xe4m3 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_16xe5m2 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];") + + +@dsl_user_op +def multimem_st_4xb32( + mc_ptr: Pointer, + x: Int32, + y: Int32, + z: Int32, + w: Int32, + *, + loc=None, + ip=None, +) -> None: + # st 4x32 bits of data + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + T.i32(), + [mc_ptr_int, x, y, z, w], + "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};", + "=r,l,r,r,r,r", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + diff --git a/python/CuTeDSL/cutlass/utils/layout.py b/python/CuTeDSL/cutlass/utils/layout.py index a1261d4d..4560c266 100644 --- a/python/CuTeDSL/cutlass/utils/layout.py +++ b/python/CuTeDSL/cutlass/utils/layout.py @@ -34,18 +34,6 @@ class LayoutEnum(Enum): else warpgroup.OperandMajorMode.MN ) - def is_k_major_a(self): - return self == LayoutEnum.ROW_MAJOR - - def is_m_major_a(self): - return self == LayoutEnum.COL_MAJOR - - def is_k_major_b(self): - return self == LayoutEnum.COL_MAJOR - - def is_n_major_b(self): - return self == LayoutEnum.ROW_MAJOR - def is_n_major_c(self): return self == LayoutEnum.ROW_MAJOR diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index 9490fb58..2500c06e 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -11,7 +11,7 @@ from typing import Type, Union, overload -from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta +from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta, CutlassBaseDSL import cutlass.cute as cute from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size @@ -40,14 +40,17 @@ class SmemAllocator: """ self._base = get_dyn_smem(Int8, alignment=1024) self._allocated_bytes = 0 + CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) @overload - def allocate(self, size_or_type: int, byte_alignment: int): ... + def allocate(self, size_or_type: int, byte_alignment: int) -> cute.Pointer: ... @overload - def allocate(self, size_or_type: cute.struct, byte_alignment: int): ... + def allocate( + self, size_or_type: cute.struct, byte_alignment: int + ) -> cute.Pointer: ... - def allocate(self, size_or_type, byte_alignment: int = 1) -> int: + def allocate(self, size_or_type, byte_alignment: int = 1) -> cute.Pointer: """Allocate a block of memory with specified size and alignment. This method adjusts the base pointer to ensure proper alignment and updates diff --git a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py index 1a4d13de..2873244d 100644 --- a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py @@ -382,3 +382,5 @@ class StaticPersistentTileScheduler: @property def num_tiles_executed(self) -> Int32: return self._num_tiles_executed + + diff --git a/python/CuTeDSL/cutlass_dsl/__init__.py b/python/CuTeDSL/cutlass_dsl/__init__.py index 4d4b4ad2..5492fb51 100644 --- a/python/CuTeDSL/cutlass_dsl/__init__.py +++ b/python/CuTeDSL/cutlass_dsl/__init__.py @@ -17,6 +17,7 @@ from ..base_dsl.ast_helpers import ( if_executor, while_selector, while_executor, + range, range_constexpr, range_dynamic, const_expr, @@ -28,6 +29,8 @@ from ..base_dsl.ast_helpers import ( all_executor, range_value_check, range_perf_warning, + cf_symbol_check, + redirect_builtin_function, ) from ..base_dsl import * @@ -38,5 +41,4 @@ from ..base_dsl._mlir_helpers.op import dsl_user_op from ..base_dsl.runtime import * from ..base_dsl.runtime import cuda as cuda_helpers from ..base_dsl.compiler import compile -from ..base_dsl.runtime.dlpack_runtime import * from ..base_dsl.runtime.jit_arg_adapters import * diff --git a/python/CuTeDSL/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass_dsl/cutlass.py index e2461d50..1630c873 100644 --- a/python/CuTeDSL/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass_dsl/cutlass.py @@ -15,12 +15,14 @@ regarding to that dialect. """ # Local module imports -from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef -from inspect import isclass +from itertools import chain +from types import GenericAlias, SimpleNamespace, UnionType +from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any import functools import pkgutil -from dataclasses import is_dataclass +from dataclasses import is_dataclass, fields from collections.abc import Sequence +import builtins from ..base_dsl import * from ..base_dsl import compiler @@ -51,20 +53,15 @@ from ..base_dsl.ast_helpers import ( while_selector, while_executor, assert_executor, + const_expr, + dynamic_expr, bool_cast, compare_executor, any_executor, all_executor, range_value_check, range_perf_warning, -) -from ..base_dsl.runtime.dlpack_runtime import ( - get_cute_tensor_c_pointer, - get_tensor_desc_shape_all, - get_tensor_desc_stride_all, - get_tensor_desc_element_type, - get_tensor_desc_is_in_device, - get_tensor_desc_assumed_align, + cf_symbol_check, ) from .cutlass_ast_decorators import ( @@ -73,6 +70,16 @@ from .cutlass_ast_decorators import ( _while_execute_dynamic, ) +from .tree_utils import ( + is_constexpr_field, + tree_flatten, + tree_unflatten, + PyTreeDef, + is_frozen_dataclass, + DSLTreeFlattenError, +) +from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry + # ============================================================================= # Cutlass DSL Base Abstract Class @@ -125,6 +132,46 @@ def is_cute_algebra_type(arg_spec): return False +def _get_c_pointers_cutlass(obj): + """ + This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict. + """ + if hasattr(obj, "__c_pointers__"): + return obj.__c_pointers__() + elif isinstance(obj, (tuple, list)): + return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj)) + elif isinstance(obj, SimpleNamespace): + return list( + chain.from_iterable( + _get_c_pointers_cutlass(x) for x in obj.__dict__.values() + ) + ) + elif isinstance(obj, dict): + return list( + chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values()) + ) + elif is_dataclass(obj): + return list( + chain.from_iterable( + _get_c_pointers_cutlass(getattr(obj, f.name)) + for f in fields(obj) + if not is_constexpr_field(f) + ) + ) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in get_c_pointers to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + else: + # Try get adapter + adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj)) + if adapter is not None: + return _get_c_pointers_cutlass(adapter(obj)) + return [] + + class CutlassBaseDSL(BaseDSL): """This abstract class provides a DSL for Cutlass.""" @@ -137,16 +184,25 @@ class CutlassBaseDSL(BaseDSL): preprocess: bool = False, ): super().__init__( - name, - compiler_provider, - pass_sm_arch_name, - device_compilation_only, - preprocess, + name=name, + dsl_package_name=["cutlass"], + compiler_provider=compiler_provider, + pass_sm_arch_name=pass_sm_arch_name, + device_compilation_only=device_compilation_only, + preprocess=preprocess, ) + self._smem_usage_tracker: tuple = None + # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: return False + # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. + def _handle_tensor_descriptor( + self, maybe_tensor, arg_name: str, need_gpu_memory: bool + ) -> Any: + return False + def _build_gpu_module(self, attrs): self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels")) with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])): @@ -229,8 +285,43 @@ class CutlassBaseDSL(BaseDSL): return version_hash + @staticmethod + def track_smem_allocator(allocator, callback): + """ + Tracks shared memory usage for kernel functions. + Find and set allocator to its parent dsl object. + """ + frame = inspect.currentframe().f_back + while frame: + obj = frame.f_locals.get("self", None) + if obj and isinstance(obj, CutlassBaseDSL): + obj._set_smem_tracking(allocator, callback) + return + frame = frame.f_back + warnings.warn("Cannot find parent dsl for allocator!", UserWarning) + + def _set_smem_tracking(self, allocator, callback): + # Registers an allocator and callback for current dsl + self._smem_usage_tracker = (allocator, callback) + + def _reset_smem_tracking(self): + # Clear an allocator and callback for current dsl + self._smem_usage_tracker = None + + def _get_smem_usage(self) -> int: + # Treat final allocated bytes of allocator as smem usage + if not self._smem_usage_tracker: + return 0 + allocator, callback = self._smem_usage_tracker + return callback(allocator) + def _kernel_helper(self, funcBody, *args, **kwargs): class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper): + def __init__(self, dsl: CutlassBaseDSL): + super().__init__() + self.dsl = dsl + self.dsl._reset_smem_tracking() + def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): super().generate_func_op(arg_types, arg_attrs, kernel_name) self.func_op = func.FuncOp( @@ -272,6 +363,17 @@ class CutlassBaseDSL(BaseDSL): if cfg.has_cluster: cfg.cluster = [to_index(size) for size in cfg.cluster] + smem_usage = self.dsl._get_smem_usage() + if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]): + pass # cannot compare dynamic value inside kernel to launch op in py + elif cfg.auto_smem: + cfg.smem = smem_usage + elif smem_usage > cfg.smem: + warnings.warn( + f"Potential error: specified kernel launch smem bytes " + f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!", + UserWarning, + ) cfg.smem = const(cfg.smem) if not isinstance(cfg.async_deps, (list, tuple)): @@ -295,12 +397,13 @@ class CutlassBaseDSL(BaseDSL): return token if is_async else None return KernelLauncher( - self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs + self, + lambda: _CutlassIrKernelGenHelper(self), + funcBody, + *args, + **kwargs, ) - def _get_module_globals(self): - return globals() - def _preprocess_launch_config_args(self, args, kwargs): """Helper to preprocess args and kwargs for LaunchConfig""" if "stream" in kwargs: @@ -316,7 +419,10 @@ class CutlassBaseDSL(BaseDSL): Validates if the arg is really of the annotated type. """ - if is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None): + if ( + is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None) + or arg_annotation is Any + ): pass else: origin = get_origin(arg_annotation) @@ -329,11 +435,12 @@ class CutlassBaseDSL(BaseDSL): f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}" ) # Handle Union types and generic types - elif origin is Union: + elif origin is Union or isinstance(arg_annotation, UnionType): # For Union types, check if arg matches any of the allowed types allowed_types = get_args(arg_annotation) if not any( - (isinstance(ty, type) and isinstance(arg, ty)) + (ty is Any) + or (isinstance(ty, type) and isinstance(arg, ty)) or (get_origin(ty) is tuple and isinstance(arg, tuple)) for ty in allowed_types ): @@ -381,6 +488,26 @@ class CutlassBaseDSL(BaseDSL): jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals) else: jit_exec_arg = jit_arg_type = jit_arg_attr = None + elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( + arg, "__new_from_mlir_values__" + ): + # Try tree_flatten + try: + dyn_vals, _ = tree_flatten(arg) + except DSLTreeFlattenError: + # If fails, just return the original arg + return jit_exec_arg, jit_arg_type, jit_arg_attr + + if dyn_vals: + jit_arg_type.extend([v.type for v in dyn_vals]) + jit_arg_attr.extend([default_attr] * len(dyn_vals)) + jit_exec_arg.extend( + _get_c_pointers_cutlass(arg) if is_host else dyn_vals + ) + else: + # If tree flatten yields empty list, treat it as a constexpr thing + # Like a dataclass with all fields are constexpr, or an empty tuple or list + jit_exec_arg = jit_arg_type = jit_arg_attr = None return jit_exec_arg, jit_arg_type, jit_arg_attr def _generate_execution_arguments_for_known_types( @@ -396,6 +523,17 @@ class CutlassBaseDSL(BaseDSL): blk_args = fop_args[iv_block_args : iv_block_args + n_args] ir_arg.append(new_from_mlir_values(arg, blk_args)) iv_block_args += n_args + elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( + arg, "__new_from_mlir_values__" + ): + # Try tree_unflatten + try: + dyn_vals, tree_def = tree_flatten(arg) + block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)] + ir_arg.append(tree_unflatten(tree_def, block_args)) + iv_block_args += len(dyn_vals) + except DSLTreeFlattenError: + return ir_arg, iv_block_args return ir_arg, iv_block_args @@ -458,10 +596,7 @@ class KernelLauncher: def _check_func_args(self, funcBody, *func_args, **func_kwargs): # Get function signature - if isinstance(funcBody, DSLCallable): - sig = funcBody.get_signature() - else: - sig = inspect.signature(funcBody) + sig = inspect.signature(funcBody) # func_args and func_kwargs should match funcBody's signature, # no extra or missing arguments. @@ -473,6 +608,12 @@ class KernelLauncher: cause=e, ) + def smem_usage(self) -> int: + """ + Check smem usage for this kernel, only available after `launch` + """ + return self.dsl._get_smem_usage() + def launch(self, *args, **kwargs): self.dsl.frame = inspect.currentframe().f_back self.dsl._preprocess_launch_config_args(args, kwargs) @@ -497,134 +638,151 @@ class KernelLauncher: # ============================================================================= # Utils # ============================================================================= - - -def is_frozen_dataclass(obj_or_cls) -> bool: +def _filter_readonly_frozen_dataclass( + iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int +) -> List[Any]: """ - Return True if obj_or_cls is a dataclass (class or instance) declared with frozen=True, - otherwise False. - """ - if not isinstance(obj_or_cls, type): - # If it's an instance, get its class - obj_or_cls = obj_or_cls.__class__ + Filter items based on whether corresponding iter_args are frozen dataclasses. - # Must be a dataclass, and __dataclass_params__.frozen must be True - return ( - is_dataclass(obj_or_cls) - and getattr(obj_or_cls, "__dataclass_params__", None) is not None - and obj_or_cls.__dataclass_params__.frozen + This function filters items (which can be values or names) based on the same + logic: keep items if they correspond to full-write arguments (index < full_write_args_count) + or if the corresponding iter_arg is not a frozen dataclass. + + Args: + iter_args: List of arguments to check for frozen dataclass status + items_to_filter: List of items to filter (values or names) + full_write_args_count: Number of arguments that are always written (not read-only) + + Returns: + Filtered list of items + + Examples: + # Filter values (original remove_read_only_frozen_dataclass behavior) + filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count) + + # Filter names (original filter_readonly_frozen_dataclass_names behavior) + filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count) + """ + return [ + item + for i, item in enumerate(items_to_filter) + if i < full_write_args_count or not is_frozen_dataclass(iter_args[i]) + ] + + +def remove_read_only_frozen_dataclass( + iter_args: List[Any], full_write_args_count: int +) -> List[Any]: + """Filter out frozen dataclass arguments that are not full-write arguments.""" + return _filter_readonly_frozen_dataclass( + iter_args, iter_args, full_write_args_count ) +def filter_readonly_frozen_dataclass_names( + iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int +) -> List[str]: + """Filter names based on whether corresponding iter_args are frozen dataclasses.""" + return _filter_readonly_frozen_dataclass( + iter_args, iter_args_names, full_write_args_count + ) + + +def insert_read_only_frozen_dataclass( + iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int +) -> List[Any]: + """ + Insert read-only frozen dataclass arguments back into the iteration arguments. + + This function takes the new iteration arguments and the original arguments, + and preserves frozen dataclass instances from the original arguments while + using the new arguments for non-frozen dataclass instances. + + Args: + iter_args: New iteration arguments to use for non-frozen dataclass instances + original_iter_args: Original iteration arguments to preserve frozen dataclass instances from + full_write_args_count: Number of arguments that are always written (not read-only) + + Returns: + List of arguments with frozen dataclass instances preserved from original + """ + # Take full-write arguments from new iter_args + full_write_args = ( + iter_args[:full_write_args_count] if full_write_args_count > 0 else [] + ) + + # Process remaining arguments: preserve frozen dataclass from original, use new for others + remaining_original = original_iter_args[full_write_args_count:] + remaining_new = iter_args[full_write_args_count:] + + def process_remaining_arg(original_arg, new_arg_iter): + """Process a single remaining argument, preserving frozen dataclass if present""" + return original_arg if is_frozen_dataclass(original_arg) else next(new_arg_iter) + + # Use zip to pair original args with new args, then map the processing function + new_arg_iter = iter(remaining_new) + processed_remaining = [ + process_remaining_arg(orig_arg, new_arg_iter) for orig_arg in remaining_original + ] + + return full_write_args + processed_remaining + + +def unpack_to_irvalue( + mixed_values: List[Any], body_name: str, full_write_args_count: int +) -> Tuple[List[ir.Value], PyTreeDef]: + log().debug("===--- Values UNPack") + for idx, packed in enumerate(mixed_values): + log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) + + try: + unpacked_values, treedef = tree_flatten( + remove_read_only_frozen_dataclass(mixed_values, full_write_args_count) + ) + except DSLTreeFlattenError as e: + raise DSLRuntimeError( + f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression.", + context={ + e.message: ( + f"All expressions within '{body_name}' must be dynamic expressions, " + "mixing Python objects and dynamic expressions is not supported. " + "The DSL failed to convert the Python object into dynamic expressions." + ) + }, + suggestion=( + f"Please ensure '{e.type_str}' implements the '{DynamicExpression.__name__}' or mark with `dataclass`, " + f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects." + ), + ) + + log().debug("------------------ ") + for idx, unpacked in enumerate(unpacked_values): + log().debug("[%d]: unpacked values: %s", idx, unpacked) + log().debug("treedef: %s", treedef) + log().debug("------------------ ") + + return unpacked_values, treedef + + def pack_from_irvalue( ir_values: List["ir.Value"], - indices: Dict[int, Tuple[int, int]], - class_types: List[Any], + pytree_def: PyTreeDef, + mixed_values: List[Any], + full_write_args_count: int, ) -> List[Any]: """ Packs MLIR values into a list of mixed values. """ log().debug("===--- Values Pack (%d)", len(ir_values)) - for idx, packed in enumerate(ir_values): - log().debug("[%d]: will-packed: %s", idx, ir_values) - for idx, unpacked in indices.items(): - log().debug("[%d]: indices: %s", idx, unpacked) - for idx, c in enumerate(class_types): - log().debug("[%d]: obj-types: %s", idx, type(c)) - - mixed_values = [None] * len(indices) - for idx, (start, length) in sorted(indices.items()): - chunk = ir_values[start : start + length] - obj = class_types[idx] - if is_frozen_dataclass(obj): - mixed_values[idx] = obj - elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"): - mixed_values[idx] = obj.__new_from_mlir_values__(chunk) - elif isinstance(chunk, list) and chunk[0] is None: - mixed_values[idx] = class_types[idx] - else: - if len(chunk) == 1: - try: - mixed_values[idx] = t.as_numeric(chunk[0]) - except ValueError: - # Suppress the conversion error and try new_from_mlir_values below - pass - - if mixed_values[idx] is None: - mixed_values[idx] = new_from_mlir_values(obj, chunk) - - log().debug("------------------ ") - for idx, packed in enumerate(mixed_values): - log().debug("[%d]: packed: %s", idx, packed) - log().debug("------------------ ") - return mixed_values - - -def unpack_to_irvalue( - mixed_values: List[Any], body_name: str -) -> Tuple[List[ir.Value], List[Any], Dict[int, Tuple[int, int]], List[Any]]: - """ - Unpacks mixed values into ir.Value values. - """ - unpacked_values = [] - ir_values = [] - indices = {} - class_types = [] - current_offset = 0 - - log().debug("===--- Values UNPack (%d)", len(mixed_values)) - for idx, packed in enumerate(mixed_values): - log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) - for idx, item in enumerate(mixed_values): - class_types.append(item) - try: - if is_frozen_dataclass(item): - extracted_vals = [None] - else: - extracted_vals = extract_mlir_values(item) - # it's consexpr (python value), so we create mlir value for it - if extracted_vals == []: - if item is None: - extracted_vals = [None] - else: - dyn_expr = t.as_numeric(item) - extracted_vals = extract_mlir_values(dyn_expr) - ir_values.extend(extracted_vals) - else: - ir_values.extend(extracted_vals) - - unpacked_values.extend(extracted_vals) - length = len(extracted_vals) - indices[idx] = (current_offset, length) - current_offset += length - except Exception as e: - raise DSLRuntimeError( - f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression (aka MLIR value).", - context={ - item: ( - f"All expressions within '{body_name}' must be dynamic expressions, " - "mixing Python objects and dynamic expressions (aka MLIR values) is not supported. " - "The DSL failed to convert the Python object into MLIR values." - ) - }, - suggestion=( - f"Please ensure '{item}' implements the '{DynamicExpression.__name__}', " - f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects." - ), - ) from e - - log().debug("------------------ ") - for idx, unpacked in enumerate(unpacked_values): - log().debug("[%d]: unpacked values: %s", idx, unpacked) - for idx, unpacked in enumerate(ir_values): - log().debug("[%d]: unpacked ir_values: %s", idx, unpacked) - for idx, unpacked in indices.items(): - log().debug("[%d]: indices: %s", idx, unpacked) - for idx, unpacked in enumerate(class_types): - log().debug("[%d]: initial-class-types: %s", idx, unpacked) + for idx, value in enumerate(ir_values): + log().debug("[%d]: will-packed: %s", idx, value) + log().debug("treedef: %s", pytree_def) log().debug("------------------ ") - return ir_values, unpacked_values, indices, class_types + unflattened = tree_unflatten(pytree_def, ir_values) + return insert_read_only_frozen_dataclass( + unflattened, mixed_values, full_write_args_count + ) def to_index(value): @@ -1015,8 +1173,8 @@ def any_(iterable): def select_(cond, if_value, else_value): def _as_scalar(value): - if const_expr(isinstance(value, list)): - if const_expr(len(value) == 1): + if isinstance(value, list): + if len(value) == 1: return value[0] else: raise DSLRuntimeError( @@ -1024,16 +1182,16 @@ def select_(cond, if_value, else_value): ) return value - if const_expr(not is_dynamic_expression(cond)): + if not is_dynamic_expression(cond): raise DSLRuntimeError("Conditional expression must be dynamic") # Extract MLIR values cond = extract_mlir_values(cond) - if const_expr(is_dynamic_expression(if_value)): + if is_dynamic_expression(if_value): if_value = extract_mlir_values(if_value) else: if_value = const(if_value) - if const_expr(is_dynamic_expression(else_value)): + if is_dynamic_expression(else_value): else_value = extract_mlir_values(else_value) else: else_value = const(else_value) @@ -1089,7 +1247,7 @@ def for_generate( iter_args: Optional[Sequence[ir.Value]] = None, *, unroll: LoopUnroll = None, - pipelining=None, + prefetch_stages=None, loc=None, ip=None, ): @@ -1127,8 +1285,8 @@ def for_generate( if unroll is not None: for_op.attributes["loop_annotation"] = unroll - if pipelining is not None: - for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining) + if prefetch_stages is not None: + for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages) iv = for_op.induction_variable new_results = new_from_mlir_values(iter_args, for_op.results) @@ -1155,11 +1313,11 @@ def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None): """ res = None # Handle Python bool first to prevent infinite recursion - if const_expr(type(lhs) == bool): + if type(lhs) == bool: res = lhs ^ True - elif const_expr(hasattr(lhs, "__dsl_not__")): + elif hasattr(lhs, "__dsl_not__"): res = lhs.__dsl_not__(loc=loc, ip=ip) - elif const_expr(is_dynamic_expression(lhs)): + elif is_dynamic_expression(lhs): # If lhs is MLIR value, compute not using xor res = arith.XOrIOp(lhs, const(1, lhs.type)).result else: @@ -1338,29 +1496,59 @@ def equal(lhs, rhs): return lhs == rhs -def in_(lhs, rhs, op): +def not_equal(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs != rhs + + # Both sequence + if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): + # Short-circuit for unequal length + if len(lhs) != len(rhs): + return True + return any_(not_equal(l, r) for l, r in zip(lhs, rhs)) + + if hasattr(lhs, "__ne__"): + return lhs != rhs + elif hasattr(rhs, "__ne__"): + return rhs != lhs + else: + return not_(equal(lhs, rhs)) + + +def in_(lhs, rhs): if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): return lhs in rhs if not isinstance(rhs, Sequence): raise DSLRuntimeError( - f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}" + f"'in' not supported between instances of {type(lhs)} and {type(rhs)}" ) return any_(equal(lhs, r) for r in rhs) -def _lt_gt(lhs, rhs, op): - def native_lt_gt(lhs, rhs, op): - if op == "<": - return lhs < rhs - elif op == ">": - return lhs > rhs - else: - raise DSLRuntimeError(f"Unsupported comparison operator: {op}") +def _lte_gte(lhs, rhs, op): + def native_lte_gte(lhs, rhs, op): + match op: + case "<": + return lhs < rhs + case "<=": + if hasattr(lhs, "__le__"): + return lhs <= rhs + else: + return not_(lhs > rhs) + case ">": + return lhs > rhs + case ">=": + if hasattr(lhs, "__ge__"): + return lhs >= rhs + else: + return not_(lhs < rhs) + case _: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): - return native_lt_gt(lhs, rhs, op) + return native_lte_gte(lhs, rhs, op) # Both sequence, comparisons other than == and != do not allow mixing different types of sequences if ( @@ -1375,7 +1563,7 @@ def _lt_gt(lhs, rhs, op): is_equal = equal(l, r) mask.append(not_(or_(is_equal, unequal_found))) unequal_found = not_(is_equal) - comp_results.append(_lt_gt(l, r, op)) + comp_results.append(_lte_gte(l, r, op)) result = any_(and_(r, m) for r, m in zip(comp_results, mask)) @@ -1383,62 +1571,126 @@ def _lt_gt(lhs, rhs, op): # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types # If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one has_valid_mask = any_(mask) - if op == "<": - length_result = len(lhs) < len(rhs) - elif op == ">": - length_result = len(lhs) > len(rhs) + match op: + case "<": + length_result = len(lhs) < len(rhs) + case ">": + length_result = len(lhs) > len(rhs) + case "<=": + length_result = len(lhs) <= len(rhs) + case ">=": + length_result = len(lhs) >= len(rhs) if type(has_valid_mask) == bool: return result if has_valid_mask else length_result else: return select_(has_valid_mask, result, length_result) else: - return result + if op in {"<=", ">="}: + # If no unequal, return True + return select_(unequal_found, result, True) + else: + return result else: - return native_lt_gt(lhs, rhs, op) + return native_lte_gte(lhs, rhs, op) def greater_than(lhs, rhs): - return _lt_gt(lhs, rhs, ">") + return _lte_gte(lhs, rhs, ">") + + +def greater_equal(lhs, rhs): + return _lte_gte(lhs, rhs, ">=") def less_than(lhs, rhs): - return _lt_gt(lhs, rhs, "<") + return _lte_gte(lhs, rhs, "<") + + +def less_equal(lhs, rhs): + return _lte_gte(lhs, rhs, "<=") + + +def _compare_dispatch(lhs, rhs, op): + """ + Dispatches the comparison operation between lhs and rhs based on the given operator. + + :param lhs: The left-hand side operand for the comparison. + :param rhs: The right-hand side operand for the comparison. + :param op: The comparison operator as a string. Supported operators are: + - "is", "is not": Python identity comparisons. + - "in", "not in": Membership tests. + - "==", "!=": Equality and inequality. + - "<", ">", "<=", ">=": Relational comparisons. + :return: The result of the comparison, which may be a boolean or a DSL-specific type. + :raises DSLRuntimeError: If the operator is not supported. + """ + match op: + # 'is' and 'is not' are pure python operators + case "is": + return lhs is rhs + case "is not": + return lhs is not rhs + case "in": + return in_(lhs, rhs) + case "not in": + return not_(in_(lhs, rhs)) + case "==": + return equal(lhs, rhs) + case "!=": + return not_equal(lhs, rhs) + case "<": + return less_than(lhs, rhs) + case ">": + return greater_than(lhs, rhs) + case ">=": + return greater_equal(lhs, rhs) + case "<=": + return less_equal(lhs, rhs) + case _: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") def _compare_executor(left, comparators, ops): - result = left + # Fast path for single comparison + if len(comparators) == 1: + return _compare_dispatch(left, comparators[0], ops[0]) + + # Chain comparison, dispatch in a loop + result = True + current = left for comparator, op in zip(comparators, ops): - # 'is' and 'is not' are pure python operators - if op == "is": - result = result is comparator - elif op == "is not": - result = result is not comparator - elif op in ["in", "not in"]: - result = in_(left, comparator, op) - elif op in ["==", "!="]: - result = equal(left, comparator) - elif op in ["<", ">="]: - result = less_than(left, comparator) - elif op in [">", "<="]: - result = greater_than(left, comparator) - else: - raise DSLRuntimeError(f"Unsupported comparison operator: {op}") - # Invert the result for NotIn, NotEq, GtE, LtE - if op in ["not in", "!=", ">=", "<="]: - result = not_(result) + cmp_result = _compare_dispatch(current, comparator, op) + result = and_(result, cmp_result) + current = comparator + return result + +def _builtin_redirector(fcn): + if fcn == builtins.max: + return max + elif fcn == builtins.min: + return min + elif fcn == builtins.any: + return any_ + elif fcn == builtins.all: + return all_ + else: + raise DSLRuntimeError(f"Unsupported built-in function: {fcn}") + + # ============================================================================= # Set the AST decorator # ============================================================================= # Set the DSL specific functions executor.set_functions( - is_dynamic_expression, - _loop_execute_range_dynamic, - _if_execute_dynamic, - _while_execute_dynamic, - _compare_executor, - any_, - all_, + is_dynamic_expression=is_dynamic_expression, + loop_execute_range_dynamic=_loop_execute_range_dynamic, + if_dynamic=_if_execute_dynamic, + while_dynamic=_while_execute_dynamic, + compare_executor=_compare_executor, + any_executor=any_, + all_executor=all_, + builtin_redirector=_builtin_redirector, ) diff --git a/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py b/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py index 370a0c9f..b5b4d895 100644 --- a/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py +++ b/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py @@ -14,13 +14,22 @@ from types import NoneType from cutlass._mlir import ir from cutlass._mlir.dialects import scf, arith from cutlass._mlir.extras import types as T +from collections.abc import Sequence -from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values +from ..base_dsl.dsl import is_dynamic_expression from ..base_dsl.ast_helpers import * from ..base_dsl.utils.logger import log from ..base_dsl import typing as t -from ..base_dsl.typing import Int32, Float32, Boolean, Numeric, get_mlir_types +from ..base_dsl.typing import ( + Int32, + Float32, + Boolean, + Numeric, + get_mlir_types, + as_numeric, +) from . import cutlass as cutlass_dsl +from .tree_utils import PyTreeDef, check_tree_equal # ============================================================================= # AST Helpers @@ -57,14 +66,6 @@ class ScfGenerator: def __init__(self): pass - @staticmethod - def fill_none(ir_values, unpacked_values): - i = 0 - for idx, item in enumerate(unpacked_values): - if item is not None: - unpacked_values[idx] = ir_values[i] - i += 1 - @staticmethod def _normalize_region_result_to_list(region_result: Any) -> List[Any]: """ @@ -82,34 +83,109 @@ class ScfGenerator: return region_result_list @staticmethod - def check_region_result(region_values, ir_values): - for i, (expected_value, actual_value) in enumerate( - zip(ir_values, region_values) + def _check_region_result(original_value, region_value, arg_name, op_type_name): + """ + Validate that a region result maintains the same type as the original value. + + This method checks for type consistency between the original value passed to a dynamic + SCF operation (like for, if, while) and the value returned from the operation's region. + + Args: + original_value: The value before entering the SCF operation region + region_value: The value returned from the SCF operation region + arg_name: Name of the argument being checked (for error reporting) + op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting + + Raises: + DSLRuntimeError: If the region value has a different type than the original value. + The error includes suggestions for using compile-time control flow instead. + + Note: + This method performs relaxed type checking that allows inheritance relationships. + For example, a child class can be returned where a parent class was expected. + However, fundamental type changes (like None to non-None, different sequence types, + or different numeric types) are not allowed in dynamic SCF operations. + """ + + def get_type_name(value): + if isinstance(value, NoneType): + return "None" + elif isinstance(value, Sequence): + return f"{type(value).__name__}<{len(value)}>" + else: + return type(value).__name__ + + # Check for type mismatches + type_mismatch = False + old_type_name = None + new_type_name = None + + # Handle None type changes + if isinstance(original_value, NoneType) != isinstance(region_value, NoneType): + type_mismatch = True + old_type_name = get_type_name(original_value) + new_type_name = get_type_name(region_value) + # Handle sequence type/length changes + elif isinstance(original_value, Sequence) and isinstance( + region_value, Sequence ): - expected_value_type = get_mlir_types(expected_value) - actual_value_type = get_mlir_types(actual_value) - if expected_value_type != actual_value_type: - return False, i, expected_value_type, actual_value_type - return True, -1, None, None + if type(original_value) != type(region_value) or len(original_value) != len( + region_value + ): + type_mismatch = True + old_type_name = get_type_name(original_value) + new_type_name = get_type_name(region_value) + # Handle numeric type changes + elif isinstance( + original_value, (Numeric, ArithValue, ir.Value, int, float, bool) + ) or isinstance( + region_value, (Numeric, ArithValue, ir.Value, int, float, bool) + ): + try: + original_numeric = as_numeric(original_value) + region_numeric = as_numeric(region_value) + if original_numeric.dtype != region_numeric.dtype: + type_mismatch = True + old_type_name = original_numeric.dtype.__name__ + new_type_name = region_numeric.dtype.__name__ + except Exception: + pass + # Handle general type changes (relaxed for inheritance) + elif type(original_value) != type(region_value): + old_type = type(original_value) + new_type = type(region_value) + if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)): + type_mismatch = True + old_type_name = old_type.__name__ + new_type_name = new_type.__name__ + + if type_mismatch: + raise DSLRuntimeError( + f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, " + f"and update to {new_type_name} inside of this `{op_type_name}` is not supported.", + suggestion=( + f"Please avoid changing type inside a dynamic `{op_type_name}`, " + f"or change to compile-time control flow by marking this `{op_type_name}` with " + f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`." + ), + ) def scf_execute_dynamic( self, op_type_name: str, - used_args: List[Any], mix_iter_args: List[Any], + full_write_args_count: int, mix_iter_arg_names: List[str], - create_op_func: Callable[ - [List[ir.Value], Dict[int, Tuple[int, int]], List[Any]], ir.Operation - ], + create_op_func: Callable[[List[ir.Value]], ir.Operation], region_builders: List[ Callable[ [ "ir.Operation", List["ir.Value"], # block_args - List[Any], # used_args List["ir.Value"], # dyn_yield_ops - Dict[int, Tuple[int, int]], + PyTreeDef, List[Any], + int, ], Any, ] @@ -119,11 +195,11 @@ class ScfGenerator: block_term_op_builder: Dict[Callable, Callable] = {}, ) -> Any: # 1) Unpack - ir_values, dyn_unpacked_values, dyn_indices, dyn_class_types = ( - cutlass_dsl.unpack_to_irvalue(mix_iter_args, op_type_name) + ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue( + mix_iter_args, op_type_name, full_write_args_count ) # 2) Create the SCF op - op = create_op_func(ir_values, dyn_indices, dyn_class_types) + op = create_op_func(ir_values) log().debug("Generated scf.%s \n[%s]", op_type_name, op) # 3) Build the regions @@ -135,76 +211,61 @@ class ScfGenerator: region_result = builder( op, block_args, - used_args, - dyn_unpacked_values, - dyn_indices, - dyn_class_types, + ir_values, + pytree_def, + mix_iter_args, + full_write_args_count, ) # Use custom terminator if provided for this builder, otherwise use default YieldOp if builder in block_term_op_builder: # Use the provided terminator generator - block_term_op_builder[builder](region_result) + block_term_op_builder[builder](region_result, full_write_args_count) else: - # For standard yield op, check result - for arg, result, name in zip( - mix_iter_args, - ( - region_result - if isinstance(region_result, list) - else [region_result] - ), - mix_iter_arg_names, - ): - if isinstance(arg, NoneType) and not isinstance( - result, NoneType - ): - raise DSLRuntimeError( - ( - f"`{name}` is None prior to this `{op_type_name}`, " - f"and update to non-None value inside of this `{op_type_name}` is not supported." - ), - suggestion=( - f"Please make sure `{name}` is not None prior to this `{op_type_name}`, " - f"or mark this `{op_type_name}` with " - f"`{'range' if op_type_name == 'for' else 'const_expr'}`." - ), - ) # Normalize region_result region_result_list = ScfGenerator._normalize_region_result_to_list( region_result ) - # Default behavior - generate YieldOp - region_values, unpacked_values, _, _ = ( - cutlass_dsl.unpack_to_irvalue(region_result_list, op_type_name) - ) - - is_match, mismatch_idx, expected_type, actual_type = ( - ScfGenerator.check_region_result(region_values, ir_values) - ) - - if not is_match: - # From unpacked index, we need to find the original index - original_idx = -1 - for unpacked_idx, (original_idx, length) in dyn_indices.items(): - if ( - mismatch_idx >= original_idx - and mismatch_idx < original_idx + length - ): - original_idx = unpacked_idx - break - raise DSLRuntimeError( - f"`{op_type_name}` expects {expected_type} type for varible `{mix_iter_arg_names[original_idx]}`, but got {actual_type}.", - suggestion=f"Please make sure `{mix_iter_arg_names[original_idx]}` type is not changed inside of `{op_type_name}`.", + # For standard yield op, check result + for arg, result, name in zip( + mix_iter_args, + region_result_list, + mix_iter_arg_names, + ): + ScfGenerator._check_region_result( + arg, result, name, op_type_name ) + + # Default behavior - generate YieldOp + region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue( + region_result_list, op_type_name, full_write_args_count + ) + + mismatch = check_tree_equal(pytree_def, yield_pytree_def) + if mismatch != -1: + # Get arg name + filterd_arg_names = ( + cutlass_dsl.filter_readonly_frozen_dataclass_names( + mix_iter_args, mix_iter_arg_names, full_write_args_count + ) + ) + + raise DSLRuntimeError( + f"`{filterd_arg_names[mismatch]}` is structured different after this `{op_type_name}`.", + suggestion=( + f"Please avoid changing type structure inside a dynamic `{op_type_name}`, " + f"or change to compile-time control flow by marking this `{op_type_name}` with " + f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`." + ), + ) + scf.YieldOp(region_values) log().debug("Completed scf.%s \n[%s]", op_type_name, op) - ScfGenerator.fill_none(op.results, unpacked_values) # 4) Pack final results final_results = cutlass_dsl.pack_from_irvalue( - unpacked_values, dyn_indices, dyn_class_types + op.results, pytree_def, mix_iter_args, full_write_args_count ) # 5) Return in a nice pattern @@ -215,28 +276,32 @@ class ScfGenerator: return final_results +def _attr_const_check(attr, expected_type, attr_name): + # Use strict type equality to prevent `bool` being accepted where `int` is required. + if is_dynamic_expression(attr) or type(attr) is not expected_type: + raise DSLRuntimeError( + f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`." + ) + + def _loop_execute_range_dynamic( func: Callable, start: Any, stop: Any, step: Any, - used_args: List[Any] = [], mix_iter_args: List[Any] = [], + full_write_args_count: int = 0, mix_iter_arg_names: List[str] = [], unroll: int = -1, unroll_full: bool = False, - pipelining: int = None, + prefetch_stages: int = None, ): """ Example: build an scf.for with optional unroll, using our universal helper. """ scf_gen = ScfGenerator() - def create_for_op( - dyn_yield_ops: List[ir.Value], - dyn_indices: Dict[int, Tuple[int, int]], - dyn_class_types: List[Any], - ): + def create_for_op(dyn_yield_ops: List[ir.Value]): for d in dyn_yield_ops: if not isinstance(d, ir.Value): raise DSLRuntimeError( @@ -254,6 +319,10 @@ def _loop_execute_range_dynamic( stop_ = stop_.ir_value() step_ = step_.ir_value() + # Attributes must be pure Python value, add a check + _attr_const_check(unroll, int, "unroll") + _attr_const_check(unroll_full, bool, "unroll_full") + # Possibly attach unroll attributes unroll_attr = None if unroll_full: @@ -262,17 +331,18 @@ def _loop_execute_range_dynamic( unroll_attr = LoopUnroll(count=unroll) log().debug("Unroll attribute: %s", unroll_attr) - pipelining_attr = None - if pipelining is not None: - if pipelining >= 0: - pipelining_attr = ir.IntegerAttr.get( - ir.IntegerType.get_signless(32), pipelining + prefetch_stages_attr = None + if prefetch_stages is not None: + _attr_const_check(prefetch_stages, int, "prefetch_stages") + if prefetch_stages >= 0: + prefetch_stages_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), prefetch_stages ) else: raise DSLRuntimeError( - f"Pipelining must be non-negative, got {pipelining}" + f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`." ) - log().debug("Pipelining attribute: %s", pipelining_attr) + log().debug("prefetch_stages attribute: %s", prefetch_stages_attr) log().debug( "Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s", @@ -303,47 +373,48 @@ def _loop_execute_range_dynamic( if unroll_attr is not None: for_op.attributes["loop_annotation"] = unroll_attr - if pipelining_attr is not None: - for_op.attributes["cutlass.pipelining"] = pipelining_attr + if prefetch_stages_attr is not None: + for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr return for_op def for_body_builder( - op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types + op, + block_args, + _, + pytree_def, + mix_iter_args, + full_write_args_count, ): - # Insert induction variable at the beginning - dyn_yield_ops.insert(0, block_args[0]) - ScfGenerator.fill_none(block_args, dyn_yield_ops) - block_args = dyn_yield_ops # scf.ForOp block_args are typically [induction_var, iter_args...] # But MLIR also gives you op.induction_variable iv = t.as_numeric(op.induction_variable) log().debug( - "For body builder: %s block_args: %s used_args: %s", + "For body builder: %s block_args: %s full_write_args_count: %s", iv, block_args, - used_args, + full_write_args_count, ) - if len(block_args) <= 1: + # block_args[1:] are iteration variables + func_args = [] + func_args.extend( + cutlass_dsl.pack_from_irvalue( + block_args[1:], pytree_def, mix_iter_args, full_write_args_count + ) + ) + if not func_args: # No iteration arguments, or only the induction var - func(iv, *used_args) + func(iv) return [] # yield nothing else: - # block_args[1:] are iteration variables - func_args = [*used_args] - func_args.extend( - cutlass_dsl.pack_from_irvalue( - block_args[1:], dyn_indices, dyn_class_types - ) - ) updated_func_args = func(iv, *func_args) return updated_func_args # Now call the universal SCF executor with a single region builder return scf_gen.scf_execute_dynamic( op_type_name="for", - used_args=used_args, mix_iter_args=mix_iter_args, + full_write_args_count=full_write_args_count, mix_iter_arg_names=mix_iter_arg_names, create_op_func=create_for_op, region_builders=[for_body_builder], @@ -354,8 +425,8 @@ def _if_execute_dynamic( pred: "ir.Value", then_block: Callable, else_block: Callable = None, - used_args: List[Any] = [], mix_yield_args: List[Any] = [], + full_write_args_count: int = 0, mix_yield_arg_names: List[str] = [], if_constexpr=None, # ignoring for brevity ): @@ -364,11 +435,7 @@ def _if_execute_dynamic( """ scf_gen = ScfGenerator() - def create_if_op( - dyn_yield_ops: List[ir.Value], - dyn_indices: Dict[int, Tuple[int, int]], - dyn_class_types: List[Any], - ): + def create_if_op(dyn_yield_ops: List[ir.Value]): # Assume final result types match the dynamic yields result_types = [arg.type for arg in dyn_yield_ops] @@ -387,11 +454,18 @@ def _if_execute_dynamic( return if_op def then_builder( - if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types + if_op, + _, + dyn_yield_ops, + pytree_def, + mix_iter_args, + full_write_args_count, ): - flat_args = [*used_args] + flat_args = [] flat_args.extend( - cutlass_dsl.pack_from_irvalue(dyn_yield_ops, dyn_indices, dyn_class_types) + cutlass_dsl.pack_from_irvalue( + dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count + ) ) return then_block(*flat_args) @@ -400,12 +474,17 @@ def _if_execute_dynamic( if else_block is not None: def else_builder( - if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types + if_op, + _, + dyn_yield_ops, + pytree_def, + mix_iter_args, + full_write_args_count, ): - flat_args = [*used_args] + flat_args = [] flat_args.extend( cutlass_dsl.pack_from_irvalue( - dyn_yield_ops, dyn_indices, dyn_class_types + dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count ) ) return else_block(*flat_args) @@ -414,8 +493,8 @@ def _if_execute_dynamic( return scf_gen.scf_execute_dynamic( op_type_name="if", - used_args=used_args, mix_iter_args=mix_yield_args, + full_write_args_count=full_write_args_count, mix_iter_arg_names=mix_yield_arg_names, create_op_func=create_if_op, region_builders=region_builders, @@ -425,9 +504,9 @@ def _if_execute_dynamic( def _while_execute_dynamic( while_before_block: Callable, while_after_block: Callable = None, - used_args=[], - yield_args=[], - yield_arg_names=[], + write_args=[], + full_write_args_count=0, + write_args_names=[], ): """ Create and return an SCF WhileOp for dynamic loops. @@ -436,8 +515,7 @@ def _while_execute_dynamic( Args: while_before_block: Function that returns (condition, updated_values) while_after_block: Function that returns updated values - used_args: Additional arguments used in the loop body - yield_args: Values that are updated in the loop + write_args: Values that are updated in the loop See create_while_function in ast_preprocessor.py for details on the input structure. """ @@ -445,11 +523,7 @@ def _while_execute_dynamic( while_op_type_name = "while" scf_gen = ScfGenerator() - def create_while_op( - dyn_yield_ops: List[ir.Value], - dyn_indices: Dict[int, Tuple[int, int]], - dyn_class_types: List[Any], - ): + def create_while_op(dyn_yield_ops: List[ir.Value]): # Create the while operation with the types from yield_args result_types = [arg.type for arg in dyn_yield_ops] try: @@ -468,14 +542,19 @@ def _while_execute_dynamic( ) from e def before_block_builder( - op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types + op, + block_args, + _, + pytree_def, + mix_iter_args, + full_write_args_count, ): # Build the before (condition) block - ScfGenerator.fill_none(block_args, dyn_yield_ops) - block_args = dyn_yield_ops - flat_args = [*used_args] + flat_args = [] flat_args.extend( - cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types) + cutlass_dsl.pack_from_irvalue( + block_args, pytree_def, mix_iter_args, full_write_args_count + ) ) log().debug("before block args: %s", flat_args) @@ -493,18 +572,15 @@ def _while_execute_dynamic( return cond, before_results - def before_block_terminator(cond_and_results): + def before_block_terminator(cond_and_results, full_write_args_count): # Generate a condition op instead of yield op cond = cond_and_results[0] before_result_list = ScfGenerator._normalize_region_result_to_list( cond_and_results[1] ) - ir_cond_list, _, _, _ = cutlass_dsl.unpack_to_irvalue( - [cond], while_op_type_name - ) - ir_cond = ir_cond_list[0] - ir_results_list, _, _, _ = cutlass_dsl.unpack_to_irvalue( - before_result_list, while_op_type_name + ir_cond = as_numeric(cond).ir_value() + ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue( + before_result_list, while_op_type_name, full_write_args_count ) log().debug( "creating scf.ConditionOp with [%s], [%s]", @@ -514,14 +590,19 @@ def _while_execute_dynamic( scf.ConditionOp(ir_cond, ir_results_list) def after_block_builder( - op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types + op, + block_args, + _, + pytree_def, + mix_iter_args, + full_write_args_count, ): # Build the after (body) block - ScfGenerator.fill_none(block_args, dyn_yield_ops) - block_args = dyn_yield_ops - flat_args = [*used_args] + flat_args = [] flat_args.extend( - cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types) + cutlass_dsl.pack_from_irvalue( + block_args, pytree_def, mix_iter_args, full_write_args_count + ) ) log().debug("after block args: %s", flat_args) @@ -541,9 +622,9 @@ def _while_execute_dynamic( # Call the universal SCF executor with two region builders return scf_gen.scf_execute_dynamic( op_type_name=while_op_type_name, - used_args=used_args, - mix_iter_args=yield_args, - mix_iter_arg_names=yield_arg_names, + mix_iter_args=write_args, + full_write_args_count=full_write_args_count, + mix_iter_arg_names=write_args_names, create_op_func=create_while_op, region_builders=[before_block_builder, after_block_builder], block_term_op_builder={ diff --git a/python/CuTeDSL/cutlass_dsl/tree_utils.py b/python/CuTeDSL/cutlass_dsl/tree_utils.py new file mode 100644 index 00000000..599b72ea --- /dev/null +++ b/python/CuTeDSL/cutlass_dsl/tree_utils.py @@ -0,0 +1,763 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin +import dataclasses +import itertools as it +from types import SimpleNamespace + +from ..base_dsl.typing import as_numeric, Numeric, Constexpr +from ..base_dsl._mlir_helpers.arith import ArithValue +from ..base_dsl.common import DSLBaseError +from .._mlir import ir + +# ============================================================================= +# Tree Utils +# ============================================================================= + + +class DSLTreeFlattenError(DSLBaseError): + """Exception raised when tree flattening fails due to unsupported types.""" + + def __init__(self, msg: str, type_str: str): + super().__init__(msg) + self.type_str = type_str + + +def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]: + """Unzip a sequence of pairs into two lists.""" + lst1, lst2 = [], [] + for x1, x2 in pairs: + lst1.append(x1) + lst2.append(x2) + return lst1, lst2 + + +def get_fully_qualified_class_name(x: Any) -> str: + """ + Get the fully qualified class name of an object. + + Args: + x: Any object + + Returns: + str: Fully qualified class name in format 'module.class_name' + + Example: + >>> get_fully_qualified_class_name([1, 2, 3]) + 'builtins.list' + """ + return f"{x.__class__.__module__}.{x.__class__.__qualname__}" + + +def is_frozen_dataclass(obj_or_cls: Any) -> bool: + """ + Check if an object or class is a frozen dataclass. + + Args: + obj_or_cls: Either a dataclass instance or class + + Returns: + bool: True if the object/class is a dataclass declared with frozen=True, + False otherwise + + Example: + >>> from dataclasses import dataclass + >>> @dataclass(frozen=True) + ... class Point: + ... x: int + ... y: int + >>> is_frozen_dataclass(Point) + True + >>> is_frozen_dataclass(Point(1, 2)) + True + """ + cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__ + + return ( + dataclasses.is_dataclass(cls) + and getattr(cls, "__dataclass_params__", None) is not None + and cls.__dataclass_params__.frozen + ) + + +def is_dynamic_expression(x: Any) -> bool: + """ + Check if an object implements the DynamicExpression protocol. + + Objects implementing this protocol must have both `__extract_mlir_values__` + and `__new_from_mlir_values__` methods. + + Args: + x: Any object to check + + Returns: + bool: True if the object implements the DynamicExpression protocol, + False otherwise + """ + return all( + hasattr(x, attr) + for attr in ("__extract_mlir_values__", "__new_from_mlir_values__") + ) + + +def is_constexpr_field(field: dataclasses.Field) -> bool: + """ + Check if a field is a constexpr field. + """ + if field.type is Constexpr: + return True + elif get_origin(field.type) is Constexpr: + return True + return False + + +# ============================================================================= +# PyTreeDef +# ============================================================================= + +class NodeType(NamedTuple): + """ + Represents a node in a pytree structure. + + Attributes: + name: String representation of the node type + to_iterable: Function to convert node to iterable form + from_iterable: Function to reconstruct node from iterable form + """ + name: str + to_iterable: Callable + from_iterable: Callable + + +class PyTreeDef(NamedTuple): + """ + Represents the structure definition of a pytree. + + Attributes: + node_type: The type of this node + node_metadata: SimpleNamespace metadata associated with this node + child_treedefs: Tuple of child tree definitions + """ + node_type: NodeType + node_metadata: SimpleNamespace + child_treedefs: tuple["PyTreeDef", ...] + + +@dataclasses.dataclass(frozen=True) +class Leaf: + """ + Represents a leaf node in a pytree structure. + + Attributes: + is_numeric: Whether this leaf contains a `Numeric` value + is_none: Whether this leaf represents None + node_metadata: SimpleNamespace metadata associated with this leaf + ir_type_str: String representation of the IR type + """ + is_numeric: bool = False + is_none: bool = False + node_metadata: SimpleNamespace = None + ir_type_str: str = None + + +# ============================================================================= +# Default to_iterable and from_iterable +# ============================================================================= + + +def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]: + """ + Extract non-method, non-function attributes from a dataclass instance. + + Args: + x: A dataclass instance + + Returns: + tuple: (field_names, field_values) lists + """ + fields = [field.name for field in dataclasses.fields(x)] + + # If the dataclass has extra fields, raise an error + for k in x.__dict__.keys(): + if k not in fields: + raise DSLTreeFlattenError( + f"`{x}` has extra field `{k}`", + type_str=get_fully_qualified_class_name(x), + ) + + if not fields: + return [], [] + + # record constexpr fields + members = [] + constexpr_fields = [] + for field in dataclasses.fields(x): + if is_constexpr_field(field): + constexpr_fields.append(field.name) + fields.remove(field.name) + v = getattr(x, field.name) + if is_dynamic_expression(v): + raise DSLTreeFlattenError( + f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`", + type_str=get_fully_qualified_class_name(x), + ) + else: + members.append(getattr(x, field.name)) + + return fields, members, constexpr_fields + + +def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a dataclass instance to iterable form for tree flattening. + + Extracts all non-method, non-function attributes that don't start with '__' + and returns them along with metadata about the dataclass. + + Args: + x: A dataclass instance + + Returns: + tuple: (metadata, members) where metadata contains type info and field names, + and members is the list of attribute values + """ + fields, members, constexpr_fields = extract_dataclass_members(x) + + metadata = SimpleNamespace( + type_str=get_fully_qualified_class_name(x), + fields=fields, + constexpr_fields=constexpr_fields, + original_obj=x, + ) + return metadata, members + + +def set_dataclass_attributes( + instance: Any, + fields: list[str], + values: Iterable[Any], + constexpr_fields: list[str], +) -> Any: + """ + Set attributes on a dataclass instance. + + Args: + instance: The dataclass instance + fields: List of field names + values: Iterable of field values + is_frozen: Whether the dataclass is frozen + + Returns: + The instance with attributes set + """ + if not fields: + return instance + + kwargs = dict(zip(fields, values)) + for field in constexpr_fields: + kwargs[field] = getattr(instance, field) + return dataclasses.replace(instance, **kwargs) + +def default_dataclass_from_iterable( + metadata: SimpleNamespace, children: Iterable[Any] +) -> Any: + """ + Reconstruct a dataclass instance from iterable form. + + Handles both regular and frozen dataclasses appropriately. + + Args: + metadata: Metadata containing type information and field names + children: Iterable of attribute values to reconstruct the instance + + Returns: + The reconstructed dataclass instance + """ + instance = metadata.original_obj + + new_instance = set_dataclass_attributes( + instance, metadata.fields, children, metadata.constexpr_fields + ) + metadata.original_obj = new_instance + return new_instance + + +def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a dynamic expression to iterable form. + + Uses the object's `__extract_mlir_values__` method to extract MLIR values. + + Args: + x: A dynamic expression object + + Returns: + tuple: (metadata, mlir_values) where metadata marks this as a dynamic expression + and mlir_values are the extracted MLIR values + """ + return ( + SimpleNamespace(is_dynamic_expression=1, original_obj=x), + x.__extract_mlir_values__(), + ) + + +def dynamic_expression_from_iterable( + metadata: SimpleNamespace, children: Iterable[Any] +) -> Any: + """ + Reconstruct a dynamic expression from iterable form. + + Uses the object's `__new_from_mlir_values__` method to reconstruct from MLIR values. + + Args: + metadata: Metadata containing the original object + children: Iterable of MLIR values to reconstruct from + + Returns: + The reconstructed dynamic expression object + """ + return metadata.original_obj.__new_from_mlir_values__(list(children)) + + +def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a dict to iterable form. + """ + if isinstance(x, SimpleNamespace): + keys = list(x.__dict__.keys()) + values = list(x.__dict__.values()) + else: + keys = list(x.keys()) + values = list(x.values()) + + return ( + SimpleNamespace( + type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys + ), + values, + ) + + +def default_dict_from_iterable( + metadata: SimpleNamespace, children: Iterable[Any] +) -> Any: + """ + Reconstruct a dict from iterable form. + """ + instance = metadata.original_obj + fields = metadata.fields + is_simple_namespace = isinstance(instance, SimpleNamespace) + + for k, v in zip(fields, children): + if is_simple_namespace: + setattr(instance, k, v) + else: + instance[k] = v + + return instance + + +# ============================================================================= +# Register pytree nodes +# ============================================================================= + +_node_types: dict[type, NodeType] = {} + + +def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType: + """ + Register a new node type for pytree operations. + + Args: + ty: The type to register + to_iter: Function to convert instances of this type to iterable form + from_iter: Function to reconstruct instances of this type from iterable form + + Returns: + NodeType: The created NodeType instance + """ + nt = NodeType(str(ty), to_iter, from_iter) + _node_types[ty] = nt + return nt + + +def register_default_node_types() -> None: + """Register default node types for pytree operations.""" + default_registrations = [ + ( + tuple, + lambda t: (SimpleNamespace(length=len(t)), list(t)), + lambda _, xs: tuple(xs), + ), + ( + list, + lambda l: (SimpleNamespace(length=len(l)), list(l)), + lambda _, xs: list(xs), + ), + ( + dict, + default_dict_to_iterable, + default_dict_from_iterable, + ), + ( + SimpleNamespace, + default_dict_to_iterable, + default_dict_from_iterable, + ), + ] + + for ty, to_iter, from_iter in default_registrations: + register_pytree_node(ty, to_iter, from_iter) + + +# Initialize default registrations +register_default_node_types() + + +# ============================================================================= +# tree_flatten and tree_unflatten +# ============================================================================= + +""" +Behavior of tree_flatten and tree_unflatten, for example: + +```python + a = (1, 2, 3) + b = MyClass(a=1, b =[1,2,3]) +``` + +yields the following tree: + +```python + tree_a = PyTreeDef(type = 'tuple', + metadata = {length = 3}, + children = [ + Leaf(type = int), + Leaf(type = int), + Leaf(type = int), + ], + ) + flattened_a = [1, 2, 3] + tree_b = PyTreeDef(type = 'MyClass', + metadata = {fields = ['a','b']}, + children = [ + PyTreeDef(type = `list`, + metadata = {length = 3}, + children = [ + Leaf(type=`int`), + Leaf(type=`int`), + Leaf(type=`int`), + ], + ), + Leaf(type=int), + ], + ) + flattened_b = [1, 1, 2, 3] +``` + +Passing the flattened values and PyTreeDef to tree_unflatten to reconstruct the original structure. + +``` python + unflattened_a = tree_unflatten(tree_a, flattened_a) + unflattened_b = tree_unflatten(tree_b, flattened_b) +``` + +yields the following structure: + +``` python + unflattened_a = (1, 2, 3) + unflattened_b = MyClass(a=1, b =[1,2,3]) +``` + +unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b. + +""" + + +def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]: + """ + Flatten a nested structure into a flat list of values and a tree definition. + + This function recursively traverses nested data structures (trees) and + flattens them into a linear list of leaf values, while preserving the + structure information in a PyTreeDef. + + Args: + x: The nested structure to flatten + + Returns: + tuple: (flat_values, treedef) where flat_values is a list of leaf values + and treedef is the tree structure definition + + Raises: + DSLTreeFlattenError: If the structure contains unsupported types + + Example: + >>> tree_flatten([1, [2, 3], 4]) + ([1, 2, 3, 4], PyTreeDef(...)) + """ + children_iter, treedef = _tree_flatten(x) + return list(children_iter), treedef + + +def get_registered_node_types_or_insert(x: Any) -> NodeType | None: + """ + Get the registered node type for an object, registering it if necessary. + + This function checks if a type is already registered for pytree operations. + If not, it automatically registers the type based on its characteristics: + - Dynamic expressions get registered with dynamic expression handlers + - Dataclasses get registered with default dataclass handlers + + Args: + x: The object to get or register a node type for + + Returns: + NodeType or None: The registered node type, or None if the type + cannot be registered + """ + node_type = _node_types.get(type(x)) + if node_type: + return node_type + elif is_dynamic_expression(x): + # If a class implements DynamicExpression protocol, register it before default dataclass one + return register_pytree_node( + type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable + ) + elif dataclasses.is_dataclass(x): + return register_pytree_node( + type(x), default_dataclass_to_iterable, default_dataclass_from_iterable + ) + else: + return None + + +def create_leaf_for_value( + x: Any, + is_numeric: bool = False, + is_none: bool = False, + node_metadata: SimpleNamespace = None, + ir_type_str: str = None, +) -> Leaf: + """ + Create a Leaf node for a given value. + + Args: + x: The value to create a leaf for + is_numeric: Whether this is a numeric value + is_none: Whether this represents None + node_metadata: Optional metadata + ir_type_str: Optional IR type string + + Returns: + Leaf: The created leaf node + """ + return Leaf( + is_numeric=is_numeric, + is_none=is_none, + node_metadata=node_metadata, + ir_type_str=ir_type_str or (str(x.type) if hasattr(x, "type") else None), + ) + + +def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]: + """ + Internal function to flatten a tree structure. + + This is the core implementation of tree flattening that handles different + types of objects including None, ArithValue, ir.Value, Numeric types, + and registered pytree node types. + + Args: + x: The object to flatten + + Returns: + tuple: (flattened_values, treedef) where flattened_values is an iterable + of leaf values and treedef is the tree structure + + Raises: + DSLTreeFlattenError: If the object type is not supported + """ + match x: + case None: + return [], create_leaf_for_value(x, is_none=True) + + case ArithValue() if is_dynamic_expression(x): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) + + case ArithValue(): + return [x], create_leaf_for_value(x, is_numeric=True) + + case ir.Value(): + return [x], create_leaf_for_value(x) + + case Numeric(): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) + + case _: + node_type = get_registered_node_types_or_insert(x) + if node_type: + node_metadata, children = node_type.to_iterable(x) + children_flat, child_trees = unzip2(map(_tree_flatten, children)) + flattened = it.chain.from_iterable(children_flat) + return flattened, PyTreeDef( + node_type, node_metadata, tuple(child_trees) + ) + + # Try to convert to numeric + try: + nval = as_numeric(x).ir_value() + return [nval], create_leaf_for_value(nval, is_numeric=True) + except Exception: + raise DSLTreeFlattenError( + "Flatten Error", get_fully_qualified_class_name(x) + ) + + +def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any: + """ + Reconstruct a nested structure from a flat list of values and tree definition. + + This is the inverse operation of tree_flatten. It takes the flattened + values and the tree structure definition to reconstruct the original + nested structure. + + Args: + treedef: The tree structure definition from tree_flatten + xs: List of flat values to reconstruct from + + Returns: + The reconstructed nested structure + + Example: + >>> flat_values, treedef = tree_flatten([1, [2, 3], 4]) + >>> tree_unflatten(treedef, flat_values) + [1, [2, 3], 4] + """ + return _tree_unflatten(treedef, iter(xs)) + + +def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any: + """ + Internal function to reconstruct a tree structure. + + This is the core implementation of tree unflattening that handles + different types of tree definitions including Leaf nodes and PyTreeDef nodes. + + Args: + treedef: The tree structure definition + xs: Iterator of flat values to reconstruct from + + Returns: + The reconstructed object + """ + match treedef: + case Leaf(is_none=True): + return None + + case Leaf( + node_metadata=metadata + ) if metadata and metadata.is_dynamic_expression: + return metadata.original_obj.__new_from_mlir_values__([next(xs)]) + + case Leaf(is_numeric=True): + return as_numeric(next(xs)) + + case Leaf(): + return next(xs) + + case PyTreeDef(): + children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) + return treedef.node_type.from_iterable(treedef.node_metadata, children) + + +def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool: + """ + Check if two tree definitions are structurally equal. + + This is a helper function for check_tree_equal that recursively compares + tree structures. + + Args: + lhs: Left tree definition (PyTreeDef or Leaf) + rhs: Right tree definition (PyTreeDef or Leaf) + + Returns: + bool: True if the trees are structurally equal, False otherwise + """ + match (lhs, rhs): + case (Leaf(), Leaf()): + return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str + + case (PyTreeDef(), PyTreeDef()): + lhs_metadata = lhs.node_metadata + rhs_metadata = rhs.node_metadata + + lhs_fields = getattr(lhs_metadata, "fields", []) + rhs_fields = getattr(rhs_metadata, "fields", []) + lhs_constexpr_fields = getattr(lhs_metadata, "constexpr_fields", []) + rhs_constexpr_fields = getattr(rhs_metadata, "constexpr_fields", []) + + return ( + lhs.node_type == rhs.node_type + and lhs_fields == rhs_fields + and lhs_constexpr_fields == rhs_constexpr_fields + and len(lhs.child_treedefs) == len(rhs.child_treedefs) + and all(map(_check_tree_equal, lhs.child_treedefs, rhs.child_treedefs)) + ) + + case _: + return False + + +def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int: + """ + Check if two tree definitions are equal and return the index of first difference. + + This function compares two tree definitions and returns the index of the + first child that differs, or -1 if they are completely equal. + + Args: + lhs: Left tree definition + rhs: Right tree definition + + Returns: + int: Index of the first differing child, or -1 if trees are equal + + Example: + >>> treedef1 = tree_flatten([1, [2, 3]])[1] + >>> treedef2 = tree_flatten([1, [2, 4]])[1] + >>> check_tree_equal(treedef1, treedef2) + 1 # The second child differs + """ + assert len(lhs.child_treedefs) == len(rhs.child_treedefs) + + def find_first_difference( + index_and_pair: tuple[int, tuple[PyTreeDef, PyTreeDef]] + ) -> int: + index, (l, r) = index_and_pair + return index if not _check_tree_equal(l, r) else -1 + + differences = map( + find_first_difference, enumerate(zip(lhs.child_treedefs, rhs.child_treedefs)) + ) + return next((diff for diff in differences if diff != -1), -1) diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index 093ed1d6..3d1d5b00 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.1.0 +nvidia-cutlass-dsl==4.2.0 diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index dd0d7c62..35507d2a 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -133,7 +133,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '4.1.0' +this.__version__ = '4.2.0' from cutlass_cppgen.backend import create_memory_pool from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index 91fbc23e..3f515aa3 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -245,6 +245,9 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, _fields_ = [ ("device_id", ctypes.c_int), ("sm_count", ctypes.c_int), + ("max_active_clusters", ctypes.c_int), + ("cluster_shape", dim3_), + ("cluster_shape_fallback", dim3_), ] class _GemmArguments(ctypes.Structure): diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py index 1b78b513..0b66ce8a 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -93,7 +93,7 @@ class CompilationOptions: opts.append(f"--include-path={incl}") arch_flag = f"-arch=sm_{self.arch}" - if self.arch == 90 and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: + if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: arch_flag += "a" opts.append(arch_flag) @@ -109,7 +109,7 @@ class CompilationOptions: options.append(bytes(str.encode(f" --include-path={incl}"))) arch_flag = f" -arch=sm_{self.arch}" - if self.arch == 90: + if self.arch in [90, 100, 101, 103, 120, 121]: arch_flag += "a" options.append(bytes(str.encode(arch_flag))) diff --git a/python/cutlass/backend/evt/backend/__init__.py b/python/cutlass/backend/evt/backend/__init__.py index a1654548..945dcf80 100644 --- a/python/cutlass/backend/evt/backend/__init__.py +++ b/python/cutlass/backend/evt/backend/__init__.py @@ -34,3 +34,5 @@ from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes +from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter +import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes diff --git a/python/cutlass/backend/evt/backend/emitter_base.py b/python/cutlass/backend/evt/backend/emitter_base.py index 39723844..72a7d8c0 100644 --- a/python/cutlass/backend/evt/backend/emitter_base.py +++ b/python/cutlass/backend/evt/backend/emitter_base.py @@ -52,6 +52,7 @@ class FusionCallbacks: self.dag_ir = dag_ir self.emit_CD = emit_CD self.cc = cc + self.evt_cc = 90 if cc >= 90 else cc if self.cc < 90: self.namespace = "threadblock" else: @@ -103,7 +104,7 @@ class FusionCallbacks: return "" evt_tmp = f""" -using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}EVT< +using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT< {node.name_camel}, """ sorted_children = self.dag_ir.get_all_inputs(node.name) @@ -140,7 +141,7 @@ using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}EVT dag_nodes = ",\n".join(dag_node_strs) return f""" -using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}TopologicalVisitor< +using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor< {DataTypeTag[node.subgraph.element_compute]}, {edge_tuples}, {dag_nodes} diff --git a/python/cutlass/backend/evt/backend/sm100_emitter.py b/python/cutlass/backend/evt/backend/sm100_emitter.py new file mode 100644 index 00000000..db521e52 --- /dev/null +++ b/python/cutlass/backend/evt/backend/sm100_emitter.py @@ -0,0 +1,116 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Emitter for Sm100 Epilogue Visitor +""" + +from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag +from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape +from cutlass_cppgen.backend import GemmOperationUniversal +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks +from cutlass_cppgen.backend.evt.ir.node import TupleEmitter + + +class Sm100CollectiveEpilogue: + def __init__(self, tile_description, + kernel_schedule, + epilogue_schedule, + element_accumulator, + element_d, + fusion_callbacks) -> None: + + self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule) + self.element_accumulator = element_accumulator + if fusion_callbacks.dag_ir.has_node("C"): + self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element + else: + self.element_c = DataType.void + self.element_d = element_d + self.schedule = epilogue_schedule + self.fusion_callbacks = fusion_callbacks + self.opclass = tile_description.math_instruction.opcode_class + + @property + def CtaTileMNK(self) -> str: + """ + The threadblock shape + """ + return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>" + + @property + def EpilogueTileType(self) -> str: + """ + The epilogue tile type + """ + return "cutlass::epilogue::collective::EpilogueTileAuto" + + @property + def Schedule(self) -> str: + return EpilogueScheduleTag[self.schedule] + + def emit(self): + tuple_emitter = TupleEmitter("int64_t") + stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl + stride_C_str = stride_D_str + if self.fusion_callbacks.dag_ir.has_node("C"): + stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl + + callback_decl, callback_name = self.fusion_callbacks.emit() + return callback_name, f""" +using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor< + {OpcodeClassTag[self.opclass]}, + {self.CtaTileMNK}, {self.EpilogueTileType}, + {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]}, + {self.Schedule}, {stride_C_str}, {stride_D_str}, + false /* IsPerColScaleSupported */, + false /* IsBlockScaleSupported */ +>; +{callback_decl} +""" + + +class Sm100Emitter: + def __init__(self, operation: GemmOperationUniversal, graph) -> None: + fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False) + + self.collective_epilogue = Sm100CollectiveEpilogue( + tile_description=operation.tile_description, + kernel_schedule=operation.tile_description.kernel_schedule, + epilogue_schedule=operation.tile_description.epilogue_schedule, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_d=fusion_callbacks.dag_ir.get_node_meta("D").element, + fusion_callbacks=fusion_callbacks + ) + + def emit(self): + return self.collective_epilogue.emit() diff --git a/python/cutlass/backend/evt/backend/sm100_nodes.py b/python/cutlass/backend/evt/backend/sm100_nodes.py new file mode 100644 index 00000000..33e77b4c --- /dev/null +++ b/python/cutlass/backend/evt/backend/sm100_nodes.py @@ -0,0 +1,134 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from pycute import product + +from cutlass_library import DataTypeSize, DataTypeTag + +from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl +import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes + +from cutlass_cppgen.backend.library import FloatRoundStyleTag + + +Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl +Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl +Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl +Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl +Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl +Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl +Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl +Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl +Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl +Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl + + +class Sm100AuxLoadImpl(AuxLoadImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor;\n" + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128) + + +class Sm100AuxStoreImpl(AuxStoreImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f""" +using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor< + EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]} +>; +""" + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, + typename {self.descriptor}::CopyOpR2S +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128) diff --git a/python/cutlass/backend/evt/epilogue.py b/python/cutlass/backend/evt/epilogue.py index 92f71a10..da446e76 100644 --- a/python/cutlass/backend/evt/epilogue.py +++ b/python/cutlass/backend/evt/epilogue.py @@ -70,7 +70,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase): # Epilogue Thread Type epilogue_thread_type = self.visitor.epilogue_thread_type - if cc == 90: + if cc_map[cc] in [90, 100]: self.arg_c_type = self.visitor.arg_c_type self.arg_d_type = self.visitor.arg_d_type output_names = self.visitor.return_names @@ -114,7 +114,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase): Helper function for extracting device pointer """ # Skip the special tensors - if cc == 90: + if cc in [90, 100]: if tensor_name in ["C", "D"]: return 0 if tensor_name not in kwargs.keys(): diff --git a/python/cutlass/backend/evt/frontend/frontend_base.py b/python/cutlass/backend/evt/frontend/frontend_base.py index c150bf20..213aafdb 100644 --- a/python/cutlass/backend/evt/frontend/frontend_base.py +++ b/python/cutlass/backend/evt/frontend/frontend_base.py @@ -56,6 +56,7 @@ from cutlass_cppgen.backend.evt.passes import ( PassPreprocessRed, PassShapeTypePropagation, ) +from cutlass_cppgen.backend.evt.passes.util import cc_map from cutlass_cppgen.backend.utils import device_cc from cutlass_cppgen.epilogue.evt_ops import permute, reshape from cutlass_cppgen.utils.datatypes import library_type @@ -119,7 +120,7 @@ class EVTFrontendBase: self.pass_manager() # Set the epilogue type self.epilogue_thread_type = self.dag_ir.epilogue_thread_type - if self.cc == 90: + if cc_map[self.cc] in [90, 100]: self.arg_c_type = self.dag_ir.arg_c_type self.arg_d_type = self.dag_ir.arg_d_type self.reduction_names = self.dag_ir.reduction_names diff --git a/python/cutlass/backend/evt/ir/node.py b/python/cutlass/backend/evt/ir/node.py index e2b3a34a..606591b8 100644 --- a/python/cutlass/backend/evt/ir/node.py +++ b/python/cutlass/backend/evt/ir/node.py @@ -43,6 +43,28 @@ from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reve from cutlass_cppgen.backend.evt.ir.tensor import Tensor +class TupleEmitter: + """ + Emit the cute tuple to C++ code + """ + def __init__(self, stride_dtype): + self.stride_dtype = stride_dtype + + def emit(self, py_tuple): + if isinstance(py_tuple, int): + if py_tuple in [0, 1]: + return f"cute::Int<{py_tuple}>" + else: + return f"{self.stride_dtype}" + elif isinstance(py_tuple, tuple): + decl = "cute::Stride<" + for item in py_tuple: + decl += self.emit(item) + ", " + return decl[:-2] + ">" + else: + raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}") + + class ImplBase: """ Base class for Node Implementation @@ -52,7 +74,15 @@ class ImplBase: self.name = node.name self.tensor = node.tensor self._type_decl = None - self.stride_dtype = "int64_t" + self.tuple_emitter = TupleEmitter("int64_t") + + @property + def stride_dtype(self): + return self.tuple_emitter.stride_dtype + + @stride_dtype.setter + def stride_dtype(self, stride_dtype): + self.tuple_emitter.stride_dtype = stride_dtype @staticmethod def match(node, problem_size: tuple): @@ -81,30 +111,13 @@ class ImplBase: """ return sub(r"(_|-)+", " ", self.name).title().replace(" ", "") - def _emit_cute_tuple(self, py_tuple): - """ - Emit the cute tuple to C++ code - """ - if isinstance(py_tuple, int): - if py_tuple in [0, 1]: - return f"cute::Int<{py_tuple}>" - else: - return f"{self.stride_dtype}" - elif isinstance(py_tuple, tuple): - decl = "cute::Stride<" - for item in py_tuple: - decl += self._emit_cute_tuple(item) + ", " - return decl[:-2] + ">" - else: - raise ValueError(f"_emit_cute_tuple only accepts tuple or int, got {type(py_tuple).__name__}") - @property def stride_mnl(self): """ Typename StrideMNL """ stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) - return self._emit_cute_tuple(stride) + return self.tuple_emitter.emit(stride) def get_non_constant_stride(self, py_tuple): if isinstance(py_tuple, int): diff --git a/python/cutlass/backend/evt/passes/pass_argument_type.py b/python/cutlass/backend/evt/passes/pass_argument_type.py index c458f799..b0c3cdbd 100644 --- a/python/cutlass/backend/evt/passes/pass_argument_type.py +++ b/python/cutlass/backend/evt/passes/pass_argument_type.py @@ -40,6 +40,7 @@ from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.util import cc_map class PassGetArgumentType(EVTPassBase): @@ -54,9 +55,9 @@ class PassGetArgumentType(EVTPassBase): def requires(self) -> None: # Check "D" is in the node list - if self.cc == 90 and (not self.dag_ir.has_node("D")): + if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")): raise SyntaxError( - "Sm90 EVT requires the epilogue to have a returned tensor D, " + "Sm90+ EVT requires the epilogue to have a returned tensor D, " "but the variable 'D' is not found in the return values.") def call(self): @@ -66,7 +67,7 @@ class PassGetArgumentType(EVTPassBase): meta = self.dag_ir.get_node_meta(node) if not meta.disabled: self.argument_types[node] = meta.underlying_impl.argument_type - if node == "D" and self.cc == 90: + if node == "D" and cc_map[self.cc] in [90, 100]: continue if isinstance(meta, TopoVisitorNode): self.get_dag_argument_type(node) @@ -111,6 +112,9 @@ class PassGetArgumentType(EVTPassBase): else: self.dag_ir.arg_c_type = self.dag_ir.arg_d_type + def sm100_set_argument_type(self): + self.sm90_set_argument_type() + def sm80_set_argument_type(self): nodes = self.dag_ir.nodes_topological_order() self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]] diff --git a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py index 5eae2f92..46976966 100644 --- a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py +++ b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py @@ -105,7 +105,7 @@ class PassDAG2Tree(EVTPassBase): output_node = None if (self.dag_ir.cc >= 90): - # For SM90, the lca should be the input node of D + # For SM90+, the lca should be the input node of D if (not self.dag_ir.has_node("D")): raise RuntimeError(f"D is not a node in the DAG IR.") output_node = "D" diff --git a/python/cutlass/backend/evt/passes/smem_size_calculator.py b/python/cutlass/backend/evt/passes/smem_size_calculator.py index 4896840e..8168c597 100644 --- a/python/cutlass/backend/evt/passes/smem_size_calculator.py +++ b/python/cutlass/backend/evt/passes/smem_size_calculator.py @@ -34,12 +34,14 @@ Compute the shared memory size in bytes """ +from math import gcd + import cutlass_library -from pycute import shape_div, product +from pycute import flatten, shape_div, product import cutlass_cppgen from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR -from cutlass_cppgen.backend.library import DataTypeSize +from cutlass_cppgen.backend.library import DataType, DataTypeSize class GetSmemSize: @@ -58,12 +60,15 @@ class GetSmemSize: # Get the epilogue tile size schedule = tile_description.epilogue_schedule if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized: - epilogue_tile_mn = (64, 32) + element_d = self.dag_ir.get_node_meta("D").element + nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32 + epi_tile_m = min(64, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative: - if tile_description.threadblock_shape[0] >= 128: - epilogue_tile_mn = (128, 32) - else: - epilogue_tile_mn = (64, 32) + epi_tile_m = min(128, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) else: raise NotImplementedError(f"Unsupported schedule: {schedule}") @@ -93,11 +98,7 @@ class GetSmemSize: self.element_d = element_d self.is_source_supported = element_c is not None - def sm90_epilogue_smem_size(self, tile_description): - """ - Compute the shared memory size of sm90 collective epilogue - """ - self.sm90_epilogue_tile(tile_description) + def sm90_or_sm100_epilogue_smem_size(self, tile_description): # Get the Fusion Storage nodes = self.dag_ir.nodes_topological_order() self.smem_types = {} @@ -139,6 +140,120 @@ class GetSmemSize: smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size]) return smem_size[0] + def sm90_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm90 collective epilogue + """ + self.sm90_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + + # + # Sm100 epilogue specific + # + + def sm100_epilogue_tile(self, tile_description): + cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1]) + mma_tile = cta_tile + + if tile_description.is_2sm: + cta_tile = (cta_tile[0] // 2, cta_tile[1]) + + if tile_description.is_2sm and mma_tile[0] == 128: + tmem_warps = (2, 2) + else: + tmem_warps = (4, 1) + + if self.dag_ir.has_node("C"): + element_c = self.dag_ir.get_node_meta("C").element + element_c_size = DataTypeSize[element_c] + else: + element_c = None + element_c_size = 0 + + element_d = self.dag_ir.get_node_meta("D").element + + DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void + + CtaM = cta_tile[0] + CtaN = cta_tile[1] + WarpM = tmem_warps[0] + WarpN = tmem_warps[1] + MaxBits = max(element_c_size, DataTypeSize[element_d]) + DpFull = 32 + M = min(CtaM, DpFull * WarpM) + + if DisableSource: + # Epilogues w/o residual load are less sensitive to smem allocation + # Target a fixed amount of compute per epilogue iteration + if MaxBits == 4: + # Make epilogue tile larger to reduce the epilogue iterations. + # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. + ComputeElts = 8192 + Nperf = ComputeElts // M + else: + ComputeElts = 4096 + Nperf = ComputeElts // M + else: + # Epilogues w/ residual load are more sensitive to smem allocation + # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + if MaxBits == 32: + Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32 + elif MaxBits == 16: + Nperf = 32 if CtaN <= 128 else 64 + else: + Nperf = 64 + + def is_m_major(layout): + return flatten(layout.stride[0]) == 1 + + if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout): + N_min_C = 8 * WarpN + elif element_c_size == 6: + N_min_C = 128 * WarpN + else: + N_min_C = (128 // element_c_size) * WarpN + + if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout): + N_min_D = 8 * WarpN + elif DataTypeSize[element_d] == 6: + N_min_D = 128 * WarpN + else: + N_min_D = (128 // DataTypeSize[element_d]) * WarpN + + N = min(CtaN, max(Nperf, N_min_C, N_min_D)) + + tile_m = M + tile_n_size = N // WarpN * WarpN + + epilogue_tile_mn = (tile_m, tile_n_size) + epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) + + stages_d = min(epi_tiles, 2) + reuse_smem_c = (element_c_size > 8) + + if reuse_smem_c: + stages_c = max(min(epi_tiles, 4), stages_d + 1) + else: + stages_c = min(epi_tiles, 4) + + # Record the epilogue tile + self.cta_tile_mnk = tuple(tile_description.threadblock_shape) + self.epilogue_tile_mn = epilogue_tile_mn + self.epi_tiles = epi_tiles + self.stages_c = stages_c + self.stages_d = stages_d + self.reuse_smem_c = reuse_smem_c + self.element_c = element_c + self.element_d = element_d + self.is_source_supported = not DisableSource + + def sm100_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm100 collective epilogue + """ + self.sm100_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + def __call__(self, tile_description): return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description) diff --git a/python/cutlass/backend/evt/passes/util.py b/python/cutlass/backend/evt/passes/util.py index ad014bf5..4b72e330 100644 --- a/python/cutlass/backend/evt/passes/util.py +++ b/python/cutlass/backend/evt/passes/util.py @@ -36,8 +36,11 @@ Utilities for passes # Map from the CC of the kernel to the EVT implementation that the CC targets cc_map = { - 80: 80, - 86: 80, - 89: 80, - 90: 90, + 80: 80, + 86: 80, + 89: 80, + 90: 90, + 100: 100, + 101: 100, + 103: 100, } diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index cf6bcc18..5e2a3a30 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -186,7 +186,7 @@ class GemmArguments2x(ArgumentBase): if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]: raise Exception("Interleaved layout not currently supported") - if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch != 90: + if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]: super().__init__(A, B, None, None, **kwargs) else: super().__init__(A, B, C, D, **kwargs) @@ -569,7 +569,9 @@ class GemmArguments3x(GemmArguments2x): # Set hardware info hw_info_ = hw_info( - 0, device_sm_count(), + 0, device_sm_count(), 0, + dim3_(0,0,0), + dim3_(0,0,0), ) self.arguments = argument_type( @@ -1324,6 +1326,8 @@ using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_nam if operation.tile_description.tile_scheduler is not None: tschedule = operation.tile_description.tile_scheduler + emit_tile_m, emit_tile_n, emit_tile_k = operation.tile_description.blackwell_threadblock_shape + values = { "operation_name": operation.procedural_name(), "operation_suffix": self.operation_suffix, @@ -1339,9 +1343,9 @@ using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_nam "element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue], "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], "arch": "cutlass::arch::Sm%d" % operation.arch, - "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), - "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), - "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "threadblock_shape_m": str(emit_tile_m), + "threadblock_shape_n": str(emit_tile_n), + "threadblock_shape_k": str(emit_tile_k), "cluster_m": str(operation.tile_description.cluster_shape[0]), "cluster_n": str(operation.tile_description.cluster_shape[1]), "cluster_k": str(operation.tile_description.cluster_shape[2]), diff --git a/python/cutlass/backend/library.py b/python/cutlass/backend/library.py index a8b113b4..a77b302d 100644 --- a/python/cutlass/backend/library.py +++ b/python/cutlass/backend/library.py @@ -41,6 +41,7 @@ from cutlass_library import ( DataType, DataTypeSize, EpilogueScheduleType, + KernelScheduleSuffixes, KernelScheduleType, MathOperation, OpcodeClass, @@ -238,6 +239,22 @@ class MathInstruction: self.math_operation = math_operation +def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule): + blackwell_threadblock_shape = tile_description.threadblock_shape + is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule]) + if cluster_shape[0] > 0: + blackwell_threadblock_shape = [ + tile_description.threadblock_shape[0] // cluster_shape[0], + tile_description.threadblock_shape[1] // cluster_shape[1], + tile_description.threadblock_shape[2] // cluster_shape[2] + ] + if is_2sm: + blackwell_threadblock_shape[0] *= 2 + else: + blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape + return blackwell_threadblock_shape, is_2sm + + class TileDescription: """ Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes, @@ -290,6 +307,8 @@ class TileDescription: # Number of warps along x, y, z directions self.warp_count = warp_count + self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule) + def clone_and_update(self, td: dict): attrs = { "cluster_shape": None, @@ -473,7 +492,7 @@ def api_version(arch, opclass, dtype): :return: API version to be used in code emission :rtype: ApiVersion """ - if (arch >= 90 and + if (arch in [90, 100, 101, 103] and opclass == OpcodeClass.TensorOp and (dtype != DataType.f64)): return ApiVersion.v3x diff --git a/python/cutlass/backend/operation.py b/python/cutlass/backend/operation.py index 1f4b26ad..10ee67bc 100644 --- a/python/cutlass/backend/operation.py +++ b/python/cutlass/backend/operation.py @@ -45,7 +45,7 @@ def supports_cluster_launch(): global _supports_cluster_launch if _supports_cluster_launch is None: major, minor = _version_splits[0], _version_splits[1] - _supports_cluster_launch = device_cc() >= 90 and (major > 11 or (major == 11 and minor >= 8)) + _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8)) return _supports_cluster_launch diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index 86374b8b..fe96f3ed 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -689,10 +689,10 @@ def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): from torch.utils.cpp_extension import load extra_cuda_cflags = ["-std=c++17"] - if cc == 90: + if cc in [90, 100, 101, 103]: # PyTorch does not currently add the sm_90a target when compute capability # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target. - extra_cuda_cflags.append("-gencode=arch=compute_90a,code=sm_90a") + extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a") with _ArchListSetter(cc): jitmodule = load( @@ -768,8 +768,8 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = "" outfile.write(cpp_source) extra_compile_args = "" - if cc == 90: - extra_compile_args = "'--generate-code=arch=compute_90a,code=[sm_90a]'" + if cc in [90, 100, 101, 103]: + extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'" _generate_setup(name, sourcedir, extra_compile_args) if jit: diff --git a/python/cutlass/epilogue/epilogue.py b/python/cutlass/epilogue/epilogue.py index 16d1fec8..a3a17506 100644 --- a/python/cutlass/epilogue/epilogue.py +++ b/python/cutlass/epilogue/epilogue.py @@ -118,7 +118,7 @@ def trace(fn, example_tensors, **kwargs): """ Trace `fn(**example_tensors)` and generates epilogue visitor - :param fn: Python callables + :param fn or str: Python callable or string of the epilogue function :param example_tensors: example inputs for fn :type example_tensors: dict @@ -153,6 +153,22 @@ def trace(fn, example_tensors, **kwargs): pass setattr(EpilogueFunctor, "__call__", staticmethod(fn)) + epilogue_functor = EpilogueFunctor(**kwargs) + epilogue_functor.trace(example_tensors) + return epilogue_functor + elif isinstance(fn, str): + class EpilogueFunctor(PythonASTFrontend): + def __init__(self, cc=None, **kwargs): + self.source = textwrap.dedent(fn) + if not cc: + cc = device_cc() + super().__init__(cc, **kwargs) + + def parse(self, example_inputs) -> None: + self.example_inputs = example_inputs + self.ast = ast.parse(self.source) + self.visit(self.ast) + epilogue_functor = EpilogueFunctor(**kwargs) epilogue_functor.trace(example_tensors) return epilogue_functor diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index da321577..f5ea0441 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -45,7 +45,7 @@ from cutlass_cppgen.utils.check import valid_stage_count from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op -_generator_ccs = [50, 60, 61, 70, 75, 80, 90] +_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100] class KernelsForDataType: @@ -258,6 +258,9 @@ class ArchOptions: self.op_class = None self.allowed_math_operations = allowed_math_operations + if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100: + return + # Identify the method within CUTLASS generator script that generates kernel # descriptions for the target CC generate_function_name = "GenerateSM" + str(kernel_cc) @@ -292,6 +295,7 @@ class ArchOptions: # find available opclasses and data types for name, op_list in manifest.operations[operation_kind][kernel_cc].items(): for op in op_list: + if operation_kind == cutlass_library.OperationKind.Gemm: if op.gemm_kind not in gemm_kinds: continue @@ -316,7 +320,7 @@ class ArchOptions: # TF32 kernels only supported on SM80 and beyond if self.cc < 80: continue - elif self.cc == 90: + elif self.cc == 90 or self.cc == 100: if (op.A.element != cutlass_library.DataType.f32 or op.B.element != cutlass_library.DataType.f32 or op.C.element != cutlass_library.DataType.f32): @@ -550,8 +554,8 @@ class OptionRegistry: def __init__(self, target_cc: int): self.registry = {} - if target_cc > 90: - raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to 90.") + if target_cc > 100 and (target_cc not in [101, 103, 120, 121]): + raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.") gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x] operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d] diff --git a/python/cutlass/op/conv.py b/python/cutlass/op/conv.py index 4f21d854..711b27da 100644 --- a/python/cutlass/op/conv.py +++ b/python/cutlass/op/conv.py @@ -212,7 +212,7 @@ class Conv2d(OperationBase): ): super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d) # Verify the kernel cc - if self.current_cc == 90: + if self.current_cc in [90, 100, 101, 103]: # The Conv2d kernel on Hopper (SM90) is currently unsupported # Revert to use SM80-tagged kernels cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index fddd0c09..a6f9b1ab 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -123,6 +123,7 @@ from cutlass_library import ( DataType, DataTypeSize, GemmUniversalMode, + KernelScheduleSuffixes, ) import cutlass_cppgen @@ -323,8 +324,8 @@ class Gemm(OperationBase): if self.op_class == cutlass_cppgen.OpcodeClass.Simt: raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') - if self.current_cc == 90: - raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90') + if self.current_cc in [90, 100, 101, 103]: + raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+') self._swizzling_functor = swizzling_functor # @@ -394,6 +395,11 @@ class Gemm(OperationBase): return (valid, msg) valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler) + + if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0: + valid = False + msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103" + return valid, msg def tile_descriptions(self) -> list: diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py index 594106f2..59f90535 100644 --- a/python/cutlass/op/gemm_grouped.py +++ b/python/cutlass/op/gemm_grouped.py @@ -133,7 +133,7 @@ class GroupedGemm(Gemm): ) # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80 - if self.current_cc == 90: + if self.current_cc in [90, 100, 101, 103]: self._reset_options(80) self._reset_operations(reset_epilogue=False) diff --git a/python/cutlass/op/op.py b/python/cutlass/op/op.py index 88ccd26e..bebf07a7 100644 --- a/python/cutlass/op/op.py +++ b/python/cutlass/op/op.py @@ -47,6 +47,7 @@ from cutlass_library import ( import cutlass_cppgen from cutlass_cppgen import get_option_registry from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.passes.util import cc_map from cutlass_cppgen.backend.utils.device import device_cc from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs @@ -251,13 +252,13 @@ class OperationBase: mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo) if not self.specified_kernel_cc: - if self.current_cc == 90: + if self.current_cc in [90, 100, 101, 103]: # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") self._reset_options(80) self._reset_operations(reset_epilogue=False) - elif self.current_cc == 90: + elif self.current_cc in [90, 100, 101, 103]: raise Exception("CUTLASS 3.0 kernels do not use different math operations. " "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`" "parameter when constructing the plan.") @@ -283,7 +284,7 @@ class OperationBase: elements_per_access = self.epilogue_functor.epilogue_vector_length if not self.specified_kernel_cc: - if self.current_cc == 90 and activation != identity: + if self.current_cc in [90, 100, 101, 103] and activation != identity: # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") @@ -291,13 +292,13 @@ class OperationBase: raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") self._reset_options(80) self._reset_operations(reset_epilogue=False) - elif (self.cc == 90 and self.current_cc != 90 and activation == identity and self._math_operation is None): + elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None): # SM80 fallback kernels are currently used. Since an identity activation is requested, # we can switch back to using SM90 kernels. - self._reset_options(90) + self._reset_options(self.cc) self._reset_operations(reset_epilogue=False) else: - if self.current_cc == 90 and activation != identity: + if self.current_cc in [90, 100, 101, 103] and activation != identity: raise Exception("Epilogues with elementwise fusion are not currently supported " "in the Python interface for 3.x kernels. To use 2.x kernels " "with fused elementwise epilogues, do not set the `kernel_cc` " @@ -385,12 +386,12 @@ class OperationBase: """ Create the epilogue visitor """ - self.epilogue_functor = EpilogueFunctorVisitor(self.cc, visitor) + self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor) # The epilogue_functor may consume too much shared memory # Reset the possible operations - if self.cc != 90: - # The shared memory is only a concern for sm90 epilogue + if self.cc not in [90, 100, 101, 103]: + # The shared memory is only a concern for sm90+ epilogue # In sm80, the epilogue and mainloop share the shared memory return @@ -400,7 +401,7 @@ class OperationBase: for operation in self.possible_operations.all_operations: td = datatypes.td_from_profiler_op(operation) # Filter invalid epilogue schedules - if td.epilogue_schedule not in [ + if cc_map[self.cc] == 90 and td.epilogue_schedule not in [ cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized, cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]: continue diff --git a/python/cutlass/utils/check.py b/python/cutlass/utils/check.py index ff76a42b..108f268b 100644 --- a/python/cutlass/utils/check.py +++ b/python/cutlass/utils/check.py @@ -36,7 +36,7 @@ Utility functions for checking constraints on kernels and calculating kernel att import ctypes -from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC +from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC import cutlass_cppgen from cutlass_cppgen.backend.library import TileDescription @@ -54,7 +54,9 @@ def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: Operatio :return: number of bytes of shared memory consumed by a single stage :rtype: int """ - m, n, k = td.threadblock_shape + m, n, k = td.blackwell_threadblock_shape + if td.is_2sm: + m //= 2 if operation_kind == OperationKind.Gemm: stage_barrier_bytes = 32 @@ -106,7 +108,7 @@ def valid_stage_count( valid for the provided device and the second element being an error message :rtype: tuple """ - if kernel_cc == 90: + if kernel_cc in [90, 100, 101, 103]: if (td.stages is None or td.stages == 0): # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically # determines the stage count to use. Thus, all settings are valid in these scenarios. @@ -157,10 +159,10 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: :rtype: tuple """ - if cc < 90: + if cc < 90 or cc in [120, 121]: if cluster_shape != [1, 1, 1]: return (False, - f"Cluster shape for pre-SM90 architectures must be [1, 1, 1]. Received cluster shape of " + f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of " f"{cluster_shape} for SM{cc}.") else: return (True, "") @@ -174,15 +176,6 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: "CUTLASS kernels currently require the third dimension of cluster shape to be 1. " f"Received cluster shape of {cluster_shape}.") - # The CUDA programming guide currently defines a maximum of 8 thread blocks per cluster - # as being portably supported (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#thread-block-clusters). - # Current CUTLASS kernels only have non-unit cluster dimensions within the first two dimensions, - # so we check that the first two dimensions of the cluster shape do not exceed 8 thread blocks in total. - blocks_in_2d = cluster_shape[0] * cluster_shape[1] - if blocks_in_2d > 8: - return (False, - f"Thread block clusters with more than 8 thread blocks are currently unsupported on SM{cc}. " - f"Received cluster shape {cluster_shape}, which has {blocks_in_2d} thread blocks.") return (True, "") @@ -211,16 +204,16 @@ def valid_schedule( kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto) epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto) tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default) - if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default): - return (False, "Non-default schedules are only supported on SM90 and beyond") + if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default): + return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)") - if (kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto): + if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)): return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") if not tile_scheduler_default: cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative] - if (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): + if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") return (True, "") diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 3121b7b0..fbe52eb5 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -348,11 +348,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', ] + block_scaled_tile_k = ['x128_', 'x256_'] + sm103_block_scaled_data_type = [ 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', ] + sm103_block_scaled_tile_k = ['x768_'] + block_scaled_cluster_size = [ '4x4x1', '2x1x1', '0x0x1' # dynamic cluster @@ -360,11 +364,12 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode block_scaled_layouts = ['tnt'] # regex list must be in kernel procedural name order - block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" - sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + sm103_block_scaled_prefetch_policy = ['tmapf'] + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" if arch in ["100a", "100f"]: kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 1d247625..0d2449e7 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -985,33 +985,38 @@ ${compile_guard_end} epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] if opcode_class_main == OpcodeClass.BlockScaledTensorOp: - is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] grouped = is_grouped(operation.gemm_kind) if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: + if is_tma_epilogue(operation.epilogue_schedule): epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped): epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: + if is_tma_epilogue(operation.epilogue_schedule): epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: + # SM103 FP4 Ultra + is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped) + ] + is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped) + ] + if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule: epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule: epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] - - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: - epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: - epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 19cef8c7..c10fe315 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -5239,7 +5239,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments. @@ -5305,7 +5305,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments. @@ -5366,7 +5366,7 @@ def GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments. @@ -5431,7 +5431,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5489,7 +5489,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments. @@ -5546,7 +5546,7 @@ def GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5601,7 +5601,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5653,7 +5653,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments @@ -5705,7 +5705,7 @@ def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5760,7 +5760,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5826,7 +5826,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5908,7 +5908,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments @@ -5965,7 +5965,7 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 1): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) is_aligned = True # layouts for ABC, their alignments will be fixed later based on the data type @@ -6056,7 +6056,7 @@ def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -6721,6 +6721,31 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): # Blackwell SM 100 generators +try: + import cutlass_library.sm100_utils + from cutlass_library.sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) +except ImportError: + import sm100_utils + from sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) + ################################################################################################### def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int: @@ -6743,6 +6768,8 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], @@ -6779,36 +6806,18 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): min_cc = 100 max_cc = thor_sm - math_instructions_1sm = [ - # tf32 -> f32 - MathInstruction( - [64, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] + math_instructions_1sm, math_instructions_2sm = generate_tf32_math_instructions_sm100(instantiation_level) - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline: - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) tile_schedulers = [ - TileSchedulerType.Default + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -6828,38 +6837,6 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): tile_schedulers=tile_schedulers) # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [128, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: @@ -6883,6 +6860,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + # layouts for ABC and their alignments. C alignment will be set later based on output type layouts = [ [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], @@ -6897,76 +6876,22 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK thor_sm = ThorSMRenumbering(cuda_version) + math_instructions_1sm, math_instructions_2sm = generate_16b_math_instructions_sm100(instantiation_level) + min_cc = 100 max_cc = thor_sm grouped = is_grouped(gemm_kind) - math_instructions_1sm = [ - # f16 -> f16 - #MathInstruction( - # [64, 64, 16], - # DataType.f16, DataType.f16, DataType.f16, - # OpcodeClass.TensorOp, - # MathOperation.multiply_add), - MathInstruction( - [64, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # f16 -> f32 - MathInstruction( - [64, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # bf16 -> f32 - MathInstruction( - [64, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1],[4,4,1] - , DynamicClusterShape - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) tile_schedulers = [ - TileSchedulerType.Default + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -7039,90 +6964,6 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # 2xSM MMA kernels - math_instructions_2sm = [ - # 128x64x16 - #MathInstruction( - # [128, 64, 16], - # DataType.f16, DataType.f16, DataType.f16, - # OpcodeClass.TensorOp, - # MathOperation.multiply_add), - # 128x128x16 - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - - # 128x256x16 - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - - # 256x128x16 - MathInstruction( - [256, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - - # 256x256x16 - MathInstruction( - [256, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: @@ -7199,6 +7040,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=591 , default_level=591 , exhaustive_level=9999) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], @@ -7219,77 +7062,18 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK epi_type = DataType.f32 grouped = is_grouped(gemm_kind) - math_instructions_1sm = [ - # inst 64x128 - MathInstruction( - [64, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x128 - MathInstruction( - [128, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x256 - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -7418,86 +7202,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # 2xSM MMA kernels - math_instructions_2sm = [ - # inst 128x128 - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x256 - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 256x128 - MathInstruction( - [256, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 256x256 - MathInstruction( - [256, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ] - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] - , DynamicClusterShape - ] for math_inst in math_instructions_2sm: tile_descriptions = [] @@ -7633,6 +7337,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=593, default_level=593, exhaustive_level=9999) + grouped = is_grouped(gemm_kind) # layouts for ABC and their alignments. @@ -7651,111 +7357,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, max_cc = 100 epi_type = DataType.f32 - math_instructions_1sm = [ - # inst 64x128 - MathInstruction( - [64, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x32 - MathInstruction( - [128, 32, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 32, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 32, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 32, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x64 - MathInstruction( - [128, 64, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 64, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 64, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 64, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x128 - MathInstruction( - [128, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x256 - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] + pruning_level = get_pruning_level_from_global_level(instantiation_level) - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_compile_time_dtype=grouped or pruning_level >= 1, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) tile_schedulers = [ TileSchedulerType.Default, @@ -7865,36 +7471,31 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) -def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + # SM100 MMA with mixed F4/F6/F8 inputs + without block scale if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] - instruction_sizes_1sm = [ - # [64, 128, 32], - [128, 128, 32], - # [64, 256, 32], - [128, 256, 32], - ] + math_instructions_1sm, math_instructions_2sm = generate_f8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - instruction_sizes_2sm = [ - # [128, 128, 32], - # [128, 256, 32], - [256, 128, 32], - [256, 256, 32], - ] + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 - ab_types = [ - DataType.f4, DataType.f6, DataType.f8, - DataType.e2m1, DataType.e3m2, DataType.e4m3, - ] - - acc_types = [ DataType.f32 ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) tile_schedulers = [ TileSchedulerType.Default, TileSchedulerType.StreamK @@ -7907,61 +7508,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): epi_type = DataType.f32 - math_instructions_1sm = [] - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - # Usage: - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - math_instructions_2sm = [] - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - cluster_shapes_1sm = [ - # [1,2,1], - [2,1,1], - [1,1,1], - # [1,4,1], - [4,4,1] - , DynamicClusterShape - ] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [2,1,1], - [1,1,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) # 1xSM MMA kernels for math_inst in math_instructions_1sm: @@ -8022,22 +7575,6 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) - cluster_shapes_2sm = [ - [2,1,1], - # [2,2,1], - # [2,4,1], - # [4,1,1], - # [4,2,1], - [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: @@ -8101,25 +7638,31 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers) def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with mixed F4/F6/F8 inputs + block scale if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + grouped = is_grouped(gemm_kind) layouts = [ [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 0]], [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], ] - instruction_sizes_1sm = [ - [128, 128, 32], [128, 256, 32], # Block scaled kernels only support M=128 for 1SM cases - ] + math_instructions_1sm, math_instructions_2sm = generate_mxf8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - instruction_sizes_2sm = [ - [256, 128, 32], - [256, 256, 32], - ] + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) ab_types = [ DataType.f4, DataType.f6, @@ -8147,64 +7690,17 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud epi_type = DataType.f32 - math_instructions_1sm = [] - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - math_instructions_2sm = [] - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - cluster_shapes_1sm = [ - [1,1,1], - # [1,2,1], - [2,1,1], - # [1,4,1], - [4,4,1] - , DynamicClusterShape - ] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,1,1], - [2,1,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) # 1xSM MMA kernels for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp tile_descriptions = [] for cluster_shape in cluster_shapes_1sm: multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -8276,24 +7772,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]] , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - cluster_shapes_2sm = [ - [2,1,1], - # [2,2,1], - # [2,4,1], - [4,1,1], - # [4,2,1], - [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], - [4,1,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -8403,6 +7883,8 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) + grouped = is_grouped(gemm_kind) # layouts for ABC and their alignments. @@ -8411,21 +7893,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], ] - instruction_sizes_1sm = [ - [128, 64, 64], - [128, 128, 64], - ] + math_instructions_1sm, math_instructions_2sm = generate_mxf4nvf4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - instruction_sizes_2sm = [ - [256, 64, 64], - [256, 128, 64], - [256, 192, 64], [256, 256, 64] - ] + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 - ab_types = [ - DataType.f4, - DataType.e2m1, - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func=change_priority_func) acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions @@ -8444,80 +7921,17 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio epi_type = DataType.f32 - math_instructions_1sm = [] - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor - ) - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) # UE4M3 scale factor - ) - - math_instructions_2sm = [] - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor - ) - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) # UE4M3 scale factor - ) - - cluster_shapes_1sm = [ - [1,1,1], - # [1,2,1], - [2,1,1], - # [1,4,1], - [4,4,1] - , DynamicClusterShape - ] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,1,1], - [2,1,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) # 1xSM MMA kernels for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp tile_descriptions = [] for cluster_shape in cluster_shapes_1sm: multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -8527,6 +7941,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio math_inst.instruction_shape[1] * multiplier_1sm[1], math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + assert math_inst.instruction_shape[2] * 4 == 256 data_types = [ { @@ -8626,37 +8041,22 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio # E2M1 x E2M1, vector size 16, UE4M3 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) - nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule] - fp4_schedule = [fp4_kernel_schedule, epi_schedule] - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) - cluster_shapes_2sm = [ - [2,1,1], - # [2,2,1], - # [2,4,1], - [4,1,1], - # [4,2,1], - [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], - [4,1,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -8765,27 +8165,29 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) - nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule] - fp4_schedule = [fp4_kernel_schedule, epi_schedule] - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - - def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): # SM100 MMA with F4 + block scale - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + if not CudaToolkitVersionSatisfies(cuda_version, 13, 0): return + grouped = is_grouped(gemm_kind) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], ] instruction_sizes_1sm = [ @@ -8794,14 +8196,32 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ instruction_sizes_2sm = [ [256, 128, 96], + [256, 192, 96], + [256, 256, 96] ] ab_types = [ + DataType.f4, DataType.e2m1, ] + sf_types = [ + DataType.ue4m3, + DataType.ue8m0 + ] + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if grouped: + return [TileSchedulerType.Default] + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + min_cc = 103 max_cc = 103 epi_type = DataType.f32 @@ -8810,7 +8230,7 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, sf_types, acc_types): is_runtime_datatype_a = is_runtime_datatype(a_type) is_runtime_datatype_b = is_runtime_datatype(b_type) @@ -8824,12 +8244,12 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ a_type, b_type, acc_type, OpcodeClass.BlockScaledTensorOp, MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor + sf_type) ) math_instructions_2sm = [] - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, sf_types, acc_types): is_runtime_datatype_a = is_runtime_datatype(a_type) is_runtime_datatype_b = is_runtime_datatype(b_type) @@ -8843,7 +8263,7 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ a_type, b_type, acc_type, OpcodeClass.BlockScaledTensorOp, MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor + sf_type) ) cluster_shapes_1sm = [ @@ -8851,15 +8271,15 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ # [1,2,1], [2,1,1], # [1,4,1], - [4,4,1] - , DynamicClusterShape + [4,4,1], + DynamicClusterShape ] # 1xSM MMA kernels for math_inst in math_instructions_1sm: tile_descriptions = [] for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = cluster_shape + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape tile_descriptions.append( TileDescription([ math_inst.instruction_shape[0] * multiplier_1sm[0], @@ -8898,8 +8318,69 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } ] + # Set alignment d based on Destination format. for layout in layouts: for data_type in data_types: # Set alignment d based on Destination format. @@ -8908,21 +8389,29 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ else: layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue # E2M1 x E2M1, vector size 32, E8 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] - fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] - fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] - # For FP4 inputs + epilogue_1sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), epilogue_1sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), epilogue_1sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch - ,fp4_schedule_enable_prefetch - ] - , gemm_kind=gemm_kind - ) + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) cluster_shapes_2sm = [ [2,1,1], @@ -8930,14 +8419,14 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ # [2,4,1], [4,1,1], # [4,2,1], - [4,4,1] - , DynamicClusterShape + [4,4,1], + DynamicClusterShape ] for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) tile_descriptions.append( TileDescription([ math_inst.instruction_shape[0] * multiplier_2sm[0], @@ -8954,7 +8443,7 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "acc_type" : math_inst.element_accumulator, "epi_type" : epi_type, "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} }, { "a_type" : math_inst.element_a, @@ -8966,7 +8455,6 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} }, - { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, @@ -8977,8 +8465,69 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } ] + # Set alignment d based on Destination format. for layout in layouts: for data_type in data_types: # Set alignment d based on Destination format. @@ -8987,21 +8536,30 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ else: layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue # E2M1 x E2M1, vector size 32, E8 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] - fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] - fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] - # For FP4 inputs + epilogue_2sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), epilogue_2sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), epilogue_2sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch - ,fp4_schedule_enable_prefetch - ] - , gemm_kind=gemm_kind - ) + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): @@ -9053,7 +8611,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): ] tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -9242,7 +8800,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9370,7 +8928,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9498,7 +9056,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9625,7 +9183,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9766,7 +9324,7 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] math_instructions_1sm = [ @@ -9943,418 +9501,6 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers) - -# -# Kernels using the stream-K tile scheduler. -# A reduced set of kernels is generated for these schedulers to reduce functional -# and perofrmance testing time. -# - -def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], - - ] - - data_types = [ - { - "a_type" : DataType.f32, - "b_type" : DataType.f32, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - } - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - math_instructions_1sm = [ - MathInstruction( - [128, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_1sm = [ - [1,2,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,2,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] - - tile_schedulers = [ - TileSchedulerType.StreamK, - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [256, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1] - , DynamicClusterShape - ] - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - -def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - # layouts for ABC and their alignments. C alignment will be set later based on output type - layouts = [ - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - math_instructions_1sm = [ - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_1sm = [ - [1,2,1], [1,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,2,1], [1,1,1] - , DynamicClusterShape - ] - - tile_schedulers = [ - TileSchedulerType.StreamK - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [256, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1] - , DynamicClusterShape - ] - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - - -def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - # layouts for ABC and their alignments. - layouts = [ - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - - math_instructions_1sm = [ - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_1sm = [ - [1,2,1], [2,1,1], [1,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,2,1], [2,1,1], [1,1,1] - , DynamicClusterShape - ] - - tile_schedulers = [ - TileSchedulerType.StreamK, - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for data_type in data_types: - if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ - ( data_type["d_type"] == DataType.e5m2 ): - continue - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [256, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1] - , DynamicClusterShape - ] - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.e4m3, - "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }] - - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - for data_type in data_types: - if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ - ( data_type["d_type"] == DataType.e5m2 ): - continue - - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) # Conv Utility functions def make_dims_and_alignments_triple(dim: int, bit_per_element_A: int, bit_per_element_B: int, bit_per_element_C: int): bit_alignment_required_by_tma = 128 @@ -11240,9 +10386,6 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version) GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version) - GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version) - - GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version) if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) @@ -11252,8 +10395,6 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) - GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version) - # StreamK is included in regular generation GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) @@ -11280,6 +10421,7 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) # # Conv # @@ -11733,7 +10875,7 @@ def define_parser(): parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") - parser.add_argument("--architectures", default='53;60;61;70;75;80;90', help="Target compute architectures") + parser.add_argument("--architectures", default='53;60;61;70;75;80;90;100', help="Target compute architectures") parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' + 'Specifying this as \"all\" includes ALL the kernels, ' + 'while not specifying this includes only the default set of kernels.') diff --git a/python/cutlass_library/heuristics_provider.py b/python/cutlass_library/heuristics_provider.py index b3f6e5c5..01a4112a 100644 --- a/python/cutlass_library/heuristics_provider.py +++ b/python/cutlass_library/heuristics_provider.py @@ -41,6 +41,7 @@ import logging import ctypes import functools + try: import builtins if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 2e1bd82a..19875a43 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -508,6 +508,8 @@ class KernelScheduleType(enum.Enum): BlockwiseTmaWarpSpecializedCooperative = enum_auto() PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() + BlockwiseTmaWarpSpecializedPingpong = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto() TmaWarpSpecialized1SmSm100 = enum_auto() TmaWarpSpecialized2SmSm100 = enum_auto() @@ -547,20 +549,35 @@ class KernelScheduleType(enum.Enum): Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() # FP4 Ultra - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() - - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() @@ -589,7 +606,8 @@ KernelScheduleTag = { KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', - KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum', + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise', KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', @@ -620,27 +638,28 @@ KernelScheduleTag = { KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', # FP4 Ultra - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise', KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", @@ -651,6 +670,19 @@ KernelScheduleTag = { KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120', KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120', KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120', @@ -681,7 +713,8 @@ KernelScheduleSuffixes = { KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized', KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', - + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', @@ -709,20 +742,20 @@ KernelScheduleSuffixes = { KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_1sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_2sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_1sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_1sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_2sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_1sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_2sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_1sm_tmapf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_2sm_tmapf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_1sm_tmapf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_2sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', @@ -730,6 +763,7 @@ KernelScheduleSuffixes = { KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', @@ -740,6 +774,21 @@ KernelScheduleSuffixes = { KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q', KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q', KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16', @@ -817,8 +866,8 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.TmaWarpSpecialized1Sm: '', EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', - EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm', - EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma', EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma', } @@ -855,6 +904,7 @@ def to_grouped_schedule(schedule, grouped): # SM90 KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative, + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong, KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum, @@ -874,6 +924,21 @@ def to_grouped_schedule(schedule, grouped): KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm, + # SM103 + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, } return group_schedule_map[schedule] @@ -1020,14 +1085,15 @@ ArchitectureNames = { # SharedMemPerCC = { - 70: 96, # 96KB of SMEM - 72: 96, # 96KB of SMEM - 75: 64, # 64KB of SMEM - 80: 163, # 163KB of SMEM - 1KB reserved for the driver - 86: 99, # 99KB of SMEM - 1KB reserved for the driver - 87: 163, # 163KB of SMEM - 1KB reserved for the driver - 89: 99, # 99KB of SMEM - 1KB reserved for the driver - 90: 227, # 227KB of SMEM - 1KB reserved for the driver + 70: 96, # 96KB of SMEM + 72: 96, # 96KB of SMEM + 75: 64, # 64KB of SMEM + 80: 163, # 163KB of SMEM - 1KB reserved for the driver + 86: 99, # 99KB of SMEM - 1KB reserved for the driver + 87: 163, # 163KB of SMEM - 1KB reserved for the driver + 89: 99, # 99KB of SMEM - 1KB reserved for the driver + 90: 227, # 227KB of SMEM - 1KB reserved for the driver + 100: 227, # 227KB of SMEM - 1KB reserved for the driver } ################################################################################################### diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index baaaac28..5733ef26 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -570,7 +570,7 @@ class Manifest: self.kernel_filter_list.append(filter_re) - def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): + def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): # Non-negative integer which determines how many kernels are instantiated. # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. # increasing first digit reduces schedule / mixed type pruning, diff --git a/python/cutlass_library/sm100_shapes.py b/python/cutlass_library/sm100_shapes.py new file mode 100644 index 00000000..32e43765 --- /dev/null +++ b/python/cutlass_library/sm100_shapes.py @@ -0,0 +1,342 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid tcgen05 shapes and cluster sizes for SM100, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user input level for that category (tcgen05 shape, cluster +size) is smaller than a shape's associated level, it will be excluded, and otherwise, included. +Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently. +Level 0 is always emitted. +""" + +try: + from .library import DynamicClusterShape +except: + from library import DynamicClusterShape + +SM100_CLUSTER_SHAPES_1SM = { + tuple(DynamicClusterShape) : 0, + # size 1 cluster + (1, 1, 1): 1, + # size 2 cluster + (1, 2, 1): 2, + (2, 1, 1): 5, + # size 4 clusters + (2, 2, 1): 6, + (1, 4, 1): 3, + (4, 1, 1): 6, + # size 8 clusters + (2, 4, 1): 7, + (4, 2, 1): 7, + (1, 8, 1): 8, + (8, 1, 1): 8, + # size 16 cluster + (4, 4, 1): 4, +} + +SM100_CLUSTER_SHAPES_2SM = { + tuple(DynamicClusterShape) : 0, + # size 2 cluster + (2, 1, 1): 1, + # size 4 clusters + (2, 2, 1): 2, + (4, 1, 1): 2, + # size 8 clusters + (2, 4, 1): 3, + (4, 2, 1): 3, + (8, 1, 1): 6, + # size 16 cluster + (4, 4, 1): 4, +} + +# MMA shapes + +# 16b Dense + +SM100_MMA_SHAPES_16b_DENSE_1SM = { + (64, 8, 16): 5, + (64, 16, 16): 2, + (64, 24, 16): 5, + (64, 32, 16): 2, + (64, 40, 16): 5, + (64, 48, 16): 5, + (64, 56, 16): 5, + (64, 64, 16): 2, + (64, 72, 16): 5, + (64, 80, 16): 5, + (64, 88, 16): 5, + (64, 96, 16): 5, + (64, 104, 16): 5, + (64, 112, 16): 5, + (64, 120, 16): 5, + (64, 128, 16): 0, + (64, 136, 16): 5, + (64, 144, 16): 5, + (64, 152, 16): 5, + (64, 160, 16): 5, + (64, 168, 16): 5, + (64, 176, 16): 5, + (64, 184, 16): 5, + (64, 192, 16): 3, + (64, 200, 16): 5, + (64, 208, 16): 5, + (64, 216, 16): 5, + (64, 224, 16): 5, + (64, 232, 16): 5, + (64, 240, 16): 5, + (64, 248, 16): 5, + (64, 256, 16): 3, + + (128, 16, 16): 2, + (128, 32, 16): 2, + (128, 48, 16): 5, + (128, 64, 16): 2, + (128, 80, 16): 5, + (128, 96, 16): 5, + (128, 112, 16): 5, + (128, 128, 16): 0, + (128, 144, 16): 5, + (128, 160, 16): 5, + (128, 176, 16): 5, + (128, 192, 16): 3, + (128, 208, 16): 5, + (128, 224, 16): 5, + (128, 240, 16): 5, + (128, 256, 16): 0, + +} + + +SM100_MMA_SHAPES_16b_DENSE_2SM = { + (128, 32, 16): 2, + (128, 64, 16): 2, + (128, 96, 16): 5, + (128, 128, 16): 0, + (128, 160, 16): 5, + (128, 192, 16): 5, + (128, 224, 16): 5, + (128, 256, 16): 0, + + (256, 32, 16): 2, + (256, 64, 16): 2, + (256, 96, 16): 5, + (256, 128, 16): 0, + (256, 160, 16): 5, + (256, 192, 16): 3, + (256, 224, 16): 5, + (256, 256, 16): 0, +} + +# TF32 Dense + +SM100_MMA_SHAPES_TF32_DENSE_1SM = { + (64, 8, 8): 5, + (64, 16, 8): 2, + (64, 24, 8): 5, + (64, 32, 8): 2, + (64, 40, 8): 5, + (64, 48, 8): 5, + (64, 56, 8): 5, + (64, 64, 8): 1, + (64, 72, 8): 5, + (64, 80, 8): 5, + (64, 88, 8): 5, + (64, 96, 8): 5, + (64, 104, 8): 5, + (64, 112, 8): 5, + (64, 120, 8): 5, + (64, 128, 8): 0, + (64, 136, 8): 5, + (64, 144, 8): 5, + (64, 152, 8): 5, + (64, 160, 8): 5, + (64, 168, 8): 5, + (64, 176, 8): 5, + (64, 184, 8): 5, + (64, 192, 8): 3, + (64, 200, 8): 5, + (64, 208, 8): 5, + (64, 216, 8): 5, + (64, 224, 8): 5, + (64, 232, 8): 5, + (64, 240, 8): 5, + (64, 248, 8): 5, + (64, 256, 8): 3, + + (128, 16, 8): 2, + (128, 32, 8): 2, + (128, 48, 8): 5, + (128, 64, 8): 2, + (128, 80, 8): 5, + (128, 96, 8): 5, + (128, 112, 8): 5, + (128, 128, 8): 0, + (128, 144, 8): 5, + (128, 160, 8): 5, + (128, 176, 8): 5, + (128, 192, 8): 3, + (128, 208, 8): 5, + (128, 224, 8): 5, + (128, 240, 8): 5, + (128, 256, 8): 0, + +} + +SM100_MMA_SHAPES_TF32_DENSE_2SM = { + (128, 32, 8): 2, + (128, 64, 8): 1, + (128, 96, 8): 5, + (128, 128, 8): 0, + (128, 160, 8): 5, + (128, 192, 8): 5, + (128, 224, 8): 5, + (128, 256, 8): 0, + + (256, 32, 8): 2, + (256, 64, 8): 1, + (256, 96, 8): 5, + (256, 128, 8): 0, + (256, 160, 8): 5, + (256, 192, 8): 5, + (256, 224, 8): 5, + (256, 256, 8): 0, +} + +# F8F6F4 +SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = { + (64, 8, 32): 4, + (64, 16, 32): 4, + (64, 24, 32): 5, + (64, 32, 32): 3, + (64, 40, 32): 5, + (64, 48, 32): 5, + (64, 56, 32): 5, + (64, 64, 32): 2, + (64, 72, 32): 5, + (64, 80, 32): 5, + (64, 88, 32): 5, + (64, 96, 32): 5, + (64, 104, 32): 5, + (64, 112, 32): 5, + (64, 120, 32): 5, + (64, 128, 32): 0, + (64, 136, 32): 5, + (64, 144, 32): 5, + (64, 152, 32): 5, + (64, 160, 32): 5, + (64, 168, 32): 5, + (64, 176, 32): 5, + (64, 184, 32): 5, + (64, 192, 32): 5, + (64, 200, 32): 5, + (64, 208, 32): 5, + (64, 216, 32): 5, + (64, 224, 32): 5, + (64, 232, 32): 5, + (64, 240, 32): 5, + (64, 248, 32): 5, + (64, 256, 32): 0, + + (128, 16, 32): 4, + (128, 32, 32): 3, + (128, 48, 32): 5, + (128, 64, 32): 2, + (128, 80, 32): 5, + (128, 96, 32): 5, + (128, 112, 32): 5, + (128, 128, 32): 0, + (128, 144, 32): 5, + (128, 160, 32): 5, + (128, 176, 32): 5, + (128, 192, 32): 5, + (128, 208, 32): 5, + (128, 224, 32): 5, + (128, 240, 32): 5, + (128, 256, 32): 0, + +} + +SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = { + (128, 32, 32): 3, + (128, 64, 32): 2, + (128, 96, 32): 5, + (128, 128, 32): 1, + (128, 160, 32): 5, + (128, 192, 32): 5, + (128, 224, 32): 5, + (128, 256, 32): 1, + + (256, 32, 32): 2, + (256, 64, 32): 2, + (256, 96, 32): 5, + (256, 128, 32): 0, + (256, 160, 32): 5, + (256, 192, 32): 5, + (256, 224, 32): 5, + (256, 256, 32): 0, +} + +# MXF8F6F4 +SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = { + (128, 64, 32): 1, + (128, 128, 32): 0, + (128, 192, 32): 1, + (128, 256, 32): 0, +} + + +SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { + (256, 64, 32): 1, + (256, 128, 32): 0, + (256, 192, 32): 1, + (256, 256, 32): 0, + + +} + +# MXF4NVF4 +SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { + (128, 64, 64): 1, + (128, 128, 64): 0, + (128, 192, 64): 1, + (128, 256, 64): 0, +} + +SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { + # Multiples of 16 for N + (256, 64, 64): 1, + (256, 128, 64): 0, + (256, 192, 64): 1, + (256, 256, 64): 0, + +} diff --git a/python/cutlass_library/sm100_utils.py b/python/cutlass_library/sm100_utils.py new file mode 100644 index 00000000..9bf24fe7 --- /dev/null +++ b/python/cutlass_library/sm100_utils.py @@ -0,0 +1,661 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM100 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List, Union, Callable + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. + + +def get_tcgen05_level_from_global_level(global_level: int): + return global_level % 10 + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm100_shapes import * +except: + from sm100_shapes import * + +########### + +def generate_tf32_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_16b_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + + +def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + if enable_runtime_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + return math_instructions_1sm, math_instructions_2sm + + +def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None): + """ + Generate all cluster shapes for SM100 at or above the given level. + + Args: + level: The global level to generate cluster shapes for. + + Returns: + A tuple of two lists of cluster shapes. + The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM. + """ + cluster_level = get_cluster_level_from_global_level(level) + + assert cluster_level >= 4 + + if change_priority_func is not None: + SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM) + SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM) + change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY) + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm + + else: + + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 3bf3edb2..fc5fdf14 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -637,7 +637,10 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue if not is_fp8 or level >= 1: - schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + if not is_blockwise(gemm_kind): + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + else: + schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) if can_do_fp8_fast_accum: if not grouped: diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index 8122b7a6..acc0c46e 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup_pycute.perform_setup() setup( name='cutlass_cppgen', - version='4.0.0', + version='4.2.0', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_library.py b/python/setup_library.py index 875ba62d..75ae8ec0 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='cutlass_library', - version='4.1.0', + version='4.2.0', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 7e8a99e0..79acef3d 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='4.1.0', + version='4.2.0', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/setup.cfg b/setup.cfg index 98791943..d2394fe6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = nvidia-cutlass -version = 4.0.0.0 +version = 4.2.0.0 [options] packages = diff --git a/test/python/cutlass/gemm/gemm_testbed.py b/test/python/cutlass/gemm/gemm_testbed.py index 50eb7a9b..6ffda5b4 100644 --- a/test/python/cutlass/gemm/gemm_testbed.py +++ b/test/python/cutlass/gemm/gemm_testbed.py @@ -153,7 +153,7 @@ class GemmUniversalLauncher: else: data_cutlass = data_ref.transpose(-1, -2).contiguous() - data_cutlass = data_cutlass_cppgen.to("cuda") + data_cutlass = data_cutlass.to("cuda") # As of this writing, few operations in PyTorch are supported with FP8 data. # Thus, we perform computation in FP32 for FP8 reference checks. diff --git a/test/python/cutlass/interface/gemm_interface.py b/test/python/cutlass/interface/gemm_interface.py index 723c4c07..2913d593 100644 --- a/test/python/cutlass/interface/gemm_interface.py +++ b/test/python/cutlass/interface/gemm_interface.py @@ -240,8 +240,8 @@ class GemmErrorTests(unittest.TestCase): """ cc = device_cc() - # F64 Tensor Core operations are only avaiable on devices with CC >= 80 - supports_tensorop_f64 = cc >= 80 + # F64 Tensor Core operations are only avaiable on certain devices + supports_tensorop_f64 = cc in [80, 89, 90] plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor) error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' @@ -288,7 +288,7 @@ class GemmErrorTests(unittest.TestCase): with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): td.stages = 3 plan.construct(td) - else: + elif cc == 90: original_kschedule = td.kernel_schedule original_eschedule = td.epilogue_schedule with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): @@ -296,10 +296,13 @@ class GemmErrorTests(unittest.TestCase): td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized td.stages = 3 plan.construct(td) - # Reset schedules td.kernel_schedule = original_kschedule td.epilogue_schedule = original_eschedule + elif cc in [100, 101, 103]: + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.stages = 3 + plan.construct(td) with ExpectException(True, f'Requested too many stages'): td.stages = 100 @@ -321,12 +324,12 @@ class GemmErrorTests(unittest.TestCase): td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized plan.construct(td) - with ExpectException(True, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): + with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto plan.construct(td) - with ExpectException(True, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): + with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized plan.construct(td) diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index d106a53d..cbc54ec5 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -2276,13 +2276,16 @@ bool TestSmall(double alpha = 1.0, double beta = 1.0, using ElementA = typename Gemm::GemmKernel::ElementA; using ElementB = typename Gemm::GemmKernel::ElementB; using TiledMma = typename Gemm::GemmKernel::TiledMma; - int alignment_bits = 128; static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); - alignment_bits = cutlass::detail::get_input_alignment_bits(); - // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. - int alignment_input = (alignment_bits / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits / cute::sizeof_bits::value); - + // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. + int alignment_bits_a = cutlass::detail::get_input_alignment_bits(); + int alignment_input_a = (alignment_bits_a / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits::value); + + int alignment_bits_b = cutlass::detail::get_input_alignment_bits(); + int alignment_input_b = (alignment_bits_b / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits::value); + + int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 0 : std::max(alignment_input_a, alignment_input_b); if constexpr (apply_alignment_offset) { // If BlockScaled, then min alignment is SFVecSize diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt index 27b5a00a..05a6d388 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt @@ -71,6 +71,7 @@ cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm120 sm120_bs_gemm_nvf4_nvf4_f32_nvf4_group_gemm_fusion.cu + sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu ) endif() diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu new file mode 100644 index 00000000..34907fe4 --- /dev/null +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu @@ -0,0 +1,362 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief Tests for device-wide grouped GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/thread/activation.h" + +#include "../../../common/cutlass_unit_test.h" +#include "../gemm_testbed_3x_ptr_array.hpp" + + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +// Pingpong kernel schedule +TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_pingpong, row_sf) { + using ElementInputA = float_e5m2_t; + using ElementInputB = float_e2m1_t; + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float4_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::float_e2m1_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using ElementSFD = ElementSF; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + constexpr int SFVectorSize = 32; + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // + // Construct CollectiveEpilogue + // + + constexpr int OutputSFVectorSize = SFVectorSize; + // D = alpha * acc + beta * C + // With Row-major BlockScaleFactor generation. + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, GmemLayoutC, + ElementC>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, AlignmentC, + ElementD, GmemLayoutC *, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA *, AlignmentA, + ElementB, GmemLayoutB *, AlignmentB, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0.5); + EXPECT_TRUE(pass); +} + + + +TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_pingpong, silu_row_sf) { + using ElementInputA = float_e5m2_t; + using ElementInputB = float_e2m1_t; + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float4_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::float_e2m1_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue4m3_t; + using ElementSFD = ElementSF; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + constexpr int SFVectorSize = 32; + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // + // Construct CollectiveEpilogue + // + + constexpr int OutputSFVectorSize = SFVectorSize; + // D = SiLu(alpha * acc + beta * C) + // With Row-major BlockScaleFactor generation. + using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor< + cutlass::epilogue::thread::SiLu, + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, GmemLayoutC, + ElementC>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, AlignmentC, + ElementD, GmemLayoutC *, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA *, AlignmentA, + ElementB, GmemLayoutB *, AlignmentB, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0.5); + EXPECT_TRUE(pass); +} + + +// Cooperative kenel schedule +TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_cooperative, row_sf) { + using ElementInputA = float_e5m2_t; + using ElementInputB = float_e2m1_t; + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float4_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::float_e2m1_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue4m3_t; + using ElementSFD = ElementSF; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + constexpr int SFVectorSize = 32; + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // + // Construct CollectiveEpilogue + // + + constexpr int OutputSFVectorSize = SFVectorSize; + // D = alpha * acc + beta * C + // With Row-major BlockScaleFactor generation. + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, GmemLayoutC, + ElementC>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, AlignmentC, + ElementD, GmemLayoutC *, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA *, AlignmentA, + ElementB, GmemLayoutB *, AlignmentB, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0.5); + EXPECT_TRUE(pass); +} + + + +TEST(SM120_Device_Gemm_e5m2t_e2m1n_e2m1t_tensorop_f32_epilogue_VS32_group_cooperative, silu_row_sf) { + using ElementInputA = float_e5m2_t; + using ElementInputB = float_e2m1_t; + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float4_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::float_e2m1_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue4m3_t; + using ElementSFD = ElementSF; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + constexpr int SFVectorSize = 32; + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // + // Construct CollectiveEpilogue + // + + constexpr int OutputSFVectorSize = SFVectorSize; + // D = SiLu(alpha * acc + beta * C) + // With Row-major BlockScaleFactor generation. + using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor< + cutlass::epilogue::thread::SiLu, + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, GmemLayoutC, + ElementC>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, AlignmentC, + ElementD, GmemLayoutC *, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA *, AlignmentA, + ElementB, GmemLayoutB *, AlignmentB, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0.5); + EXPECT_TRUE(pass); +} +#endif // #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)