From fd6cfe1ed0f4e511ca78aeb63abbd494c0efde68 Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Tue, 22 Jul 2025 10:03:55 +0800 Subject: [PATCH] v4.1 release update v2. (#2481) --- CHANGELOG.md | 23 +- FUNCTIONALITY.md | 30 + README.md | 19 +- .../volta_tensorop_gemm.cu | 2 +- .../turing_tensorop_gemm.cu | 2 +- .../turing_tensorop_conv2dfprop.cu | 2 +- .../device/b2b_implicit_gemm_convolution.h | 2 +- .../fused_two_gemms_grouped_f16_sm80_rf.cu | 2 +- .../13_two_tensor_op_fusion/kernel/b2b_gemm.h | 6 +- .../kernel/default_b2b_conv2d_fprop_sm80.h | 2 +- ...t_b2b_conv2d_fprop_smem_accumulator_sm80.h | 2 +- .../threadblock/b2b_mma_pipelined.h | 2 +- .../b2b_mma_pipelined_smem_accumulator.h | 2 +- .../threadblock/default_b2b_mma.h | 2 +- .../ampere_tensorop_conv2dfprop.cu | 2 +- .../ampere_fp64_tensorop_affine2_gemm.cu | 2 +- .../tensorop_canonical.cu | 2 +- examples/20_simt_canonical/simt_canonical.cu | 2 +- .../gather_scatter_fusion.cu | 2 +- .../gemm_with_layernorm.h | 2 +- examples/40_cutlass_py/customizable/conv2d.py | 2 +- examples/40_cutlass_py/customizable/gemm.py | 2 +- .../fused_multihead_attention_fixed_seqlen.cu | 2 +- ...sed_multihead_attention_variable_seqlen.cu | 2 +- .../gemm/find_default_mma.h | 4 +- .../gemm/mma_from_smem.h | 2 +- .../default_warp_iterator_from_smem.h | 2 +- .../predicated_tile_iterator_residual_last.h | 2 +- .../ell_block_sparse_gemm.cu | 2 +- .../ir_gen/gen_device.py | 2 +- .../ir_gen/gen_verify.py | 4 +- .../ir_gen/helper.py | 4 +- .../49_collective_builder.cu | 2 +- .../50_hopper_gemm_with_epilogue_swizzle.cu | 2 +- .../52_hopper_gather_scatter_fusion.cu | 2 +- .../53_hopper_gemm_permute.cu | 2 +- .../59_ampere_gather_scatter_conv/README.md | 2 +- .../ampere_gather_scatter_conv.cu | 4 +- .../ada_fp8_gemm_grouped.cu | 2 +- ...specialized_gemm_with_blockwise_scaling.cu | 2 +- .../blackwell_gemm_streamk.cu | 2 +- .../75_blackwell_grouped_gemm_block_scaled.cu | 2 +- .../76_blackwell_conv_dgrad.cu | 2 +- .../76_blackwell_conv_fprop.cu | 2 +- .../76_blackwell_conv_wgrad.cu | 2 +- .../77_blackwell_fmha/77_blackwell_fmha.cu | 6 +- .../77_blackwell_fmha_bwd.cu | 123 +- .../77_blackwell_fmha/77_blackwell_mla_fwd.cu | 6 +- examples/77_blackwell_fmha/CMakeLists.txt | 4 + examples/77_blackwell_fmha/README.md | 8 +- .../device/fmha_device_bwd.hpp | 26 +- .../kernel/fmha_kernel_bwd_convert.hpp | 14 +- .../kernel/fmha_kernel_bwd_sum_OdO.hpp | 6 +- ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 58 +- ...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 1813 ++++++++++++ .../reference/fmha_bwd_reference.hpp | 77 +- .../reference/fmha_fwd_reference.hpp | 14 +- ...9d_blackwell_geforce_nvfp4_grouped_gemm.cu | 2 +- .../cute/tutorial/blackwell/01_mma_sm100.cu | 4 +- .../tutorial/blackwell/02_mma_tma_sm100.cu | 4 +- .../blackwell/03_mma_tma_multicast_sm100.cu | 4 +- .../blackwell/04_mma_tma_2sm_sm100.cu | 4 +- .../blackwell/05_mma_tma_epi_sm100.cu | 4 +- examples/cute/tutorial/tiled_copy.cu | 2 +- .../python/CuTeDSL/ampere/elementwise_add.py | 25 +- .../CuTeDSL/ampere/elementwise_apply.py | 2 +- .../CuTeDSL/ampere/flash_attention_v2.py | 212 +- examples/python/CuTeDSL/ampere/sgemm.py | 131 +- .../python/CuTeDSL/ampere/tensorop_gemm.py | 150 +- .../dense_blockscaled_gemm_persistent.py | 2467 +++++++++++++++++ .../python/CuTeDSL/blackwell/dense_gemm.py | 10 +- .../blackwell/dense_gemm_persistent.py | 128 +- .../blackwell/dense_gemm_software_pipeline.py | 8 +- examples/python/CuTeDSL/blackwell/fmha.py | 175 +- .../python/CuTeDSL/blackwell/grouped_gemm.py | 373 ++- .../blackwell/mamba2_ssd/mamba2_ssd.py | 271 +- examples/python/CuTeDSL/hopper/dense_gemm.py | 159 +- examples/python/CuTeDSL/notebooks/print.ipynb | 64 + .../python/CuTeDSL/notebooks/tensor.ipynb | 10 - .../python/deprecated/03_basic_conv2d.ipynb | 2 +- include/cutlass/arch/wmma.h | 2 +- include/cutlass/arch/wmma_sm70.h | 2 +- include/cutlass/arch/wmma_sm72.h | 4 +- include/cutlass/arch/wmma_sm75.h | 4 +- include/cutlass/array_subbyte.h | 2 +- .../conv/kernel/default_conv2d_fprop.h | 4 +- .../conv/kernel/default_conv2d_fprop_fusion.h | 2 +- .../conv/kernel/default_conv3d_fprop_fusion.h | 2 +- include/cutlass/cuda_host_adapter.hpp | 4 +- .../sm120_blockscaled_sparse_mma_builder.inl | 2 +- .../sm120_blockscaled_sparse_mma_tma.hpp | 4 +- .../gemm/collective/sm120_sparse_mma_tma.hpp | 2 +- .../gemm/collective/sm80_mma_multistage.hpp | 2 +- include/cutlass/gemm/device/ell_gemm.h | 4 +- include/cutlass/gemm/device/gemm.h | 2 +- include/cutlass/gemm/device/gemm_array.h | 2 +- include/cutlass/gemm/device/gemm_batched.h | 2 +- include/cutlass/gemm/device/gemm_complex.h | 2 +- .../gemm/device/gemm_splitk_parallel.h | 2 +- include/cutlass/gemm/device/gemm_universal.h | 2 +- .../gemm/device/gemm_universal_adapter.h | 2 +- .../gemm_universal_streamk_with_broadcast.h | 2 +- .../gemm/device/gemm_universal_with_absmax.h | 2 +- .../device/gemm_universal_with_broadcast.h | 2 +- .../gemm/device/gemm_with_k_reduction.h | 2 +- include/cutlass/gemm/device/rank_2k.h | 2 +- include/cutlass/gemm/device/rank_k.h | 2 +- include/cutlass/gemm/device/symm.h | 2 +- include/cutlass/gemm/device/trmm.h | 6 +- .../gemm/kernel/default_symm_complex.h | 4 +- .../gemm/kernel/gemv_batched_strided.h | 2 +- .../kernel/sm100_tile_scheduler_group.hpp | 2 +- .../kernel/sm100_tile_scheduler_stream_k.hpp | 2 +- ...specialized_cooperative_asymmetric_dma.hpp | 2 +- .../gemm/kernel/tile_scheduler_params.h | 6 +- .../gemm/threadblock/default_ell_mma.h | 6 +- .../cutlass/gemm/threadblock/default_mma.h | 6 +- .../gemm/threadblock/default_mma_core_sm80.h | 2 +- .../default_mma_layernorm_mainloop_fusion.h | 2 +- .../default_mma_softmax_mainloop_fusion.h | 2 +- .../threadblock/default_mma_with_reduction.h | 2 +- .../gemm/threadblock/default_sparse_mma.h | 4 +- .../cutlass/gemm/threadblock/default_trmm.h | 10 +- include/cutlass/gemm/threadblock/gemv.h | 2 +- .../mma_planar_complex_multistage.h | 2 +- .../gemm/threadblock/mma_singlestage.h | 2 +- .../threadblock/threadblock_swizzle_streamk.h | 2 +- .../cutlass/gemm/warp/mma_complex_tensor_op.h | 2 +- .../warp/mma_complex_tensor_op_fast_f32.h | 2 +- .../gemm/warp/mma_mixed_input_tensor_op.h | 6 +- .../cutlass/gemm/warp/mma_sparse_tensor_op.h | 2 +- .../warp/mma_tensor_op_fragment_iterator.h | 2 +- .../warp/mma_tensor_op_tile_access_iterator.h | 2 +- include/cutlass/half.h | 2 +- include/cutlass/kernel_hardware_info.h | 2 +- include/cutlass/matrix.h | 78 +- include/cutlass/numeric_conversion.h | 4 +- .../collective/sm90_wgmma_transpose.hpp | 4 +- .../transform/pitch_linear_thread_map.h | 12 +- .../ell_predicated_tile_iterator.h | 2 +- .../predicated_tile_access_iterator.h | 2 +- .../threadblock/predicated_tile_iterator.h | 2 +- .../predicated_tile_iterator_2dthreadtile.h | 6 +- ...edicated_tile_iterator_triangular_matrix.h | 4 +- .../transform/warp/vector_fragment_iterator.h | 2 +- media/docs/pythonDSL/cute_dsl.rst | 1 + .../cute_dsl_general/dsl_control_flow.rst | 2 +- .../dsl_jit_compilation_options.rst | 50 + media/docs/pythonDSL/limitations.rst | 24 +- python/CuTeDSL/base_dsl/ast_helpers.py | 45 +- python/CuTeDSL/base_dsl/ast_preprocessor.py | 19 +- python/CuTeDSL/base_dsl/compiler.py | 65 + python/CuTeDSL/base_dsl/dsl.py | 136 +- python/CuTeDSL/base_dsl/typing.py | 20 +- python/CuTeDSL/cutlass/cute/__init__.py | 3 + python/CuTeDSL/cutlass/cute/arch/__init__.py | 1 + .../cutlass/cute/arch/nvvm_wrappers.py | 102 +- python/CuTeDSL/cutlass/cute/arch/smem.py | 12 + python/CuTeDSL/cutlass/cute/core.py | 230 +- .../cutlass/cute/nvgpu/cpasync/helpers.py | 6 - python/CuTeDSL/cutlass/cute/nvgpu/helpers.py | 8 +- .../cutlass/cute/nvgpu/tcgen05/__init__.py | 5 + .../cutlass/cute/nvgpu/tcgen05/copy.py | 192 ++ .../cutlass/cute/nvgpu/tcgen05/helpers.py | 27 + .../CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py | 418 ++- python/CuTeDSL/cutlass/cute/testing.py | 56 +- python/CuTeDSL/cutlass/pipeline/sm90.py | 245 -- python/CuTeDSL/cutlass/torch.py | 2 +- python/CuTeDSL/cutlass/utils/__init__.py | 15 + .../CuTeDSL/cutlass/utils/ampere_helpers.py | 8 + .../cutlass/utils/blackwell_helpers.py | 106 +- .../cutlass/utils/blockscaled_layout.py | 287 ++ .../CuTeDSL/cutlass/utils/hopper_helpers.py | 28 +- .../CuTeDSL/cutlass/utils/smem_allocator.py | 11 +- python/CuTeDSL/cutlass/utils/smem_capacity.py | 26 + python/CuTeDSL/cutlass_dsl/cutlass.py | 37 +- python/CuTeDSL/requirements.txt | 2 +- python/cutlass_library/emit_kernel_listing.py | 26 +- .../blockwise_gemm_reference_operation.h | 145 +- 179 files changed, 7878 insertions(+), 1286 deletions(-) create mode 100644 FUNCTIONALITY.md create mode 100644 examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp create mode 100644 examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py create mode 100644 media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst create mode 100644 python/CuTeDSL/cutlass/utils/blockscaled_layout.py create mode 100644 python/CuTeDSL/cutlass/utils/smem_capacity.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 67cc633a..4bb775ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,23 +2,15 @@ # CUTLASS 4.x -## [4.1.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-06-30) +## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16) ### CuTe DSL +* Add aarch64 support, you can now pip install `nvidia-cutlass-dsl` on GB200 systems! * More examples demonstrating how to use CuTe DSL to write peak-performance kernels - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py) + - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py) * API updates - - for loop - - Python built-in ``range`` now always generates IR and executes at runtime - - ``cutlass.range`` is advanced ``range`` with IR level unrolling and pipelining control - - Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range`` - - **Experimental** Added ``pipelining`` control for compiler generated software pipeline code - - while/if - - ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate - - Deprecated ``cutlass.dynamic_expr``, please remove it - - Rename mbarrier functions to reduce ambiguity - - Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier` - - Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional. + - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details ### CUTLASS C++ * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). @@ -31,7 +23,7 @@ - Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX. - Factor out `print_latex` and friends and rewrite. - Factor out `print_svg` and friends and rewrite. -* Support Blackwell SM100 SIMT FFMA2 kernels. +* Support Blackwell SM100 SIMT packed fp32x2 kernels. * Support residual add for implicit gemm kernels. * Various fixes for CUTLASS C++ Python interface's EVT tracer: - Add verifier for sm90 to report the invalid input. @@ -41,6 +33,9 @@ * Fix profiler bugs in exhaustive perf search. - Fix incorrect cluster shape output issue when doing exhaustive search. - Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders. +* Fix some profiler issues. + - Complete the reference for Blackwell blockwise gemm kernels. + - Fix incorrect regex logic for L1 test. ## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03) @@ -61,7 +56,7 @@ - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py) * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks) * API updates - - Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer`` + - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details ### CUTLASS C++ * Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9 diff --git a/FUNCTIONALITY.md b/FUNCTIONALITY.md new file mode 100644 index 00000000..a038b0cb --- /dev/null +++ b/FUNCTIONALITY.md @@ -0,0 +1,30 @@ +# Changelog for CuTe DSL API changes + +## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16) + + * for loop + - Python built-in ``range`` now always generates IR and executes at runtime + - ``cutlass.range`` is advanced ``range`` with IR level unrolling and pipelining control + - Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range`` + - **Experimental** Added ``pipelining`` control for compiler generated software pipeline code + * while/if + - ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate + - Deprecated ``cutlass.dynamic_expr``, please remove it + * Rename mbarrier functions to reduce ambiguity + * Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier` + * Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional. + * Introduce `cutlass.cute.arch.get_dyn_smem_size` api to get runtime dynamic shared memory size. + * Various API Support for SM100 BlockScaled Gemm + - Introduce BlockScaled MmaOps in [tcgen05/mma.py]([https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py]), and provide a `make_blockscaled_trivial_tiled_mma` function in [blackwell_helpers.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blackwell_helpers.py) to help construct a BlockScaled TiledMma. + - Introduce S2T CopyOps in [tcgen05/copy.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py). + - Introduce BlockScaled layout utilities in [blockscaled_layout.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blockscaled_layout.py) for creating the required scale factor layouts in global memory, shared memory and tensor memory. + * `cutlass.cute.compile` now supports compilation options. Refer to [JIT compilation options](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.html) for more details. + * `cutlass.cute.testing.assert_` now works for device JIT function. Specify `--enable-device-assertions` as compilation option to enable. + * `cutlass.cute.make_tiled_copy` is now deprecated. Please use `cutlass.cute.make_tiled_copy_tv` instead. + * Shared memory capacity query + - Introduce `cutlass.utils.get_smem_capacity_in_bytes` for querying the shared memory capacity. + - `_utils.SMEM_CAPACITY[""]` is now deprecated. + +## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03) + + * Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer`` diff --git a/README.md b/README.md index 709dffd7..c9807a96 100644 --- a/README.md +++ b/README.md @@ -46,20 +46,12 @@ To get started quickly - please refer : # What's New in CUTLASS 4.1 ## CuTe DSL +* Add aarch64 support, you can now pip install `nvidia-cutlass-dsl` on GB200 systems! * More examples demonstrating how to use CuTe DSL to write peak-performance kernels - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py) + - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py) * API updates - - for loop - - Python built-in ``range`` now always generates IR and executes at runtime - - ``cutlass.range`` is advanced ``range`` with IR level unrolling and pipelining control - - Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range`` - - **Experimental** Added ``pipelining`` control for compiler generated software pipeline code - - while/if - - ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate - - Deprecated ``cutlass.dynamic_expr``, please remove it - - Rename mbarrier functions to reduce ambiguity - - Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier` - - Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional. + - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details ## CUTLASS C++ * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). @@ -72,7 +64,7 @@ To get started quickly - please refer : - Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX. - Factor out `print_latex` and friends and rewrite. - Factor out `print_svg` and friends and rewrite. -* Support Blackwell SM100 SIMT FFMA2 kernels. +* Support Blackwell SM100 SIMT packed fp32x2 kernels. * Support residual add for implicit gemm kernels. * Various fixes for CUTLASS C++ Python interface's EVT tracer: - Add verifier for sm90 to report the invalid input. @@ -82,6 +74,9 @@ To get started quickly - please refer : * Fix profiler bugs in exhaustive perf search. - Fix incorrect cluster shape output issue when doing exhaustive search. - Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders. +* Fix some profiler issues. + - Complete the reference for Blackwell blockwise gemm kernels. + - Fix incorrect regex logic for L1 test. Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. diff --git a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu index d92d423b..8bad0bbd 100644 --- a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu +++ b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu @@ -64,7 +64,7 @@ ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutla ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB -to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C +to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X + diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu index cdb6c679..70afef7e 100644 --- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu +++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu @@ -64,7 +64,7 @@ ElementComputeEpilogue (int32_t), ElementInputA (int8_t), ElementInputB (int8_t) (int32_t). Communicating just the data type is not enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB to row major and LayoutOutput to row -major. Next, we setup rules to comptue alpha * X + beta * C which is called epilogue of the kernel. +major. Next, we setup rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X + beta * C). diff --git a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu index cdb3c310..b5889645 100644 --- a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu +++ b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu @@ -66,7 +66,7 @@ ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out linearly in memory, we have to convey the layout of tensors. We do that by initializing template variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup -rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template +rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of elements per vector memory access (32), data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X + beta * C). diff --git a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h index 37f81374..e780537a 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h @@ -177,7 +177,7 @@ public: if(args.split_k_mode == SplitKMode::kParallel) { // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. - // The user needs to call a reduction operator to optain the final output tensor + // The user needs to call a reduction operator to obtain the final output tensor workspace_bytes = sizeof(ElementAccumulator) * size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) * diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu index f4df3e1d..53e89ec0 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu @@ -153,7 +153,7 @@ struct Options { out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n" << " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n" - << " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n" + << " run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n" << " back-to-back GEMMs are subject to.s" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h index 6070f86a..ef11f4ab 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -248,7 +248,7 @@ struct B2bGemm { typename Epilogue::OutputTileIterator::TensorRef* ref_C1; typename Epilogue::OutputTileIterator::TensorRef* ref_D1; - // Epilogue params remain constant across all problmes in the group. Thus, + // Epilogue params remain constant across all problems in the group. Thus, // the parameter here is not a pointer. typename OutputOp0::Params epilogue0; typename OutputOp1::Params epilogue1; @@ -402,7 +402,7 @@ struct B2bGemm { typename Epilogue::OutputTileIterator::TensorRef* ref_C1; typename Epilogue::OutputTileIterator::TensorRef* ref_D1; - // Epilogue params remain constant across all problmes in the group. Thus, + // Epilogue params remain constant across all problems in the group. Thus, // the parameter here is not a pointer. typename OutputOp0::Params output_op_0; typename OutputOp1::Params output_op_1; @@ -434,7 +434,7 @@ struct B2bGemm { // Only row-major outputs are currently supported, so no transpose is performed } - /// Returns non-grouped paramaters to be used as input to the kernel-level + /// Returns non-grouped parameters to be used as input to the kernel-level /// operator for the problem indicated by problem_visitor. CUTLASS_HOST_DEVICE Params to_single_params(const ProblemVisitor& problem_visitor) const { diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h index feb238cf..6e0af4f9 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h @@ -560,7 +560,7 @@ struct DefaultB2bConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and // multistage pipeline with interleaved layout. template < typename ElementA, diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h index eca1c611..ac80be8e 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h @@ -606,7 +606,7 @@ struct DefaultB2bConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and // multistage pipeline with interleaved layout. /// Accumulator will be staged in shared memory. template < diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h index d8a9d4c6..28fcc94f 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h @@ -277,7 +277,7 @@ public: IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory - FragmentC0 const &src_accum, ///< source accumualtor tile + FragmentC0 const &src_accum, ///< source accumulator tile OutputOp output_op_0, ///< epilogue operation after 1st Gemm TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h index eb23879b..b3754945 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -298,7 +298,7 @@ public: IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory - FragmentC0 const &src_accum, ///< source accumualtor tile + FragmentC0 const &src_accum, ///< source accumulator tile OutputOp output_op_0, ///< epilogue operation after 1st Gemm TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment diff --git a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h index b7aa1ffe..cbbc24a8 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h +++ b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h @@ -93,7 +93,7 @@ template < typename InstructionShape_, /// Number of stages used in the pipelined mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, /// Epilogue output operator typename EpilogueOutputOp, diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index 86e3a966..91f6a4bb 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -203,7 +203,7 @@ requires any memory for scratch space. If yes, we reserve scratch space and pass it along with other arguments to initialize the CUTLASS kernel. -After lauching the CUTLASS kernel, this example runs +After launching the CUTLASS kernel, this example runs a reference convolution kernel (from CUTLASS utilities) to check correctness. */ diff --git a/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu index a5a94c85..8e0094f6 100644 --- a/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu +++ b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu @@ -144,7 +144,7 @@ int run() { // Construct Gemm ProblemSize with user defined output size cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024}; - // Stride factor shows the distance between two elements in the differnet dimensions. The + // Stride factor shows the distance between two elements in the different dimensions. The // first data is the logical distance between two rows, the second is between two columns. // CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory::layout_factory // to help to convert stride_factor to the two strides. diff --git a/examples/19_tensorop_canonical/tensorop_canonical.cu b/examples/19_tensorop_canonical/tensorop_canonical.cu index 473e7ff4..b18b5387 100644 --- a/examples/19_tensorop_canonical/tensorop_canonical.cu +++ b/examples/19_tensorop_canonical/tensorop_canonical.cu @@ -55,7 +55,7 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// -// Define the overal warp-level problem shape +// Define the overall warp-level problem shape int const kM = 27; int const kN = 31; int const kK = 17; diff --git a/examples/20_simt_canonical/simt_canonical.cu b/examples/20_simt_canonical/simt_canonical.cu index bec2c04d..c1e0b7f7 100644 --- a/examples/20_simt_canonical/simt_canonical.cu +++ b/examples/20_simt_canonical/simt_canonical.cu @@ -59,7 +59,7 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// -// Define the overal warp-level problem shape +// Define the overall warp-level problem shape int const kM = 14; int const kN = 27; int const kK = 17; diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index badde725..016102c7 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -30,7 +30,7 @@ **************************************************************************************************/ // This example fuses gather before GEMM and scatter after GEMM into the same -// GEMM kernel. Gather and scatter operation is controled by an index vector +// GEMM kernel. Gather and scatter operation is controlled by an index vector // to select rows or columns from A, B, C or D matrices. // // Suppose, all matrices are column major. The pseudo code of the fused kernel diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h index 8411807d..813f7560 100644 --- a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h @@ -87,7 +87,7 @@ public: using ElementLayernormCompute = ElementLayernormCompute_; using ThreadblockShape = ThreadblockShape_; - // Pre-processing has ensured the layout equivelent to RowMajor + // Pre-processing has ensured the layout equivalent to RowMajor using Layout = cutlass::layout::RowMajor; using TensorVariance = TensorRef; diff --git a/examples/40_cutlass_py/customizable/conv2d.py b/examples/40_cutlass_py/customizable/conv2d.py index e03e6dba..d2df3ed8 100644 --- a/examples/40_cutlass_py/customizable/conv2d.py +++ b/examples/40_cutlass_py/customizable/conv2d.py @@ -87,7 +87,7 @@ parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices "TensorNHWC", "TensorNC32HW32"], help="Memory layout of input tensor A") parser.add_argument('-aa', '--alignment_a', default=1, - type=int, help="Memory alignement of input tensor A") + type=int, help="Memory alignment of input tensor A") # B parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[ "TensorNHWC", "TensorC32RSK32"], diff --git a/examples/40_cutlass_py/customizable/gemm.py b/examples/40_cutlass_py/customizable/gemm.py index 8e0013f3..3494fe53 100644 --- a/examples/40_cutlass_py/customizable/gemm.py +++ b/examples/40_cutlass_py/customizable/gemm.py @@ -86,7 +86,7 @@ parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], help="Memory layout of input tensor A") parser.add_argument('-aa', '--alignment_a', default=1, - type=int, help="Memory alignement of input tensor A") + type=int, help="Memory alignment of input tensor A") # B parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu index e21839c6..5dad08d2 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu @@ -55,7 +55,7 @@ ``` In practice, and for numerical stability reasons, - we also substract the maximum so far (`mi`) before doing + we also subtract the maximum so far (`mi`) before doing the exponential. When we encounter new keys, the maximum used to compute O so far (`m_prime`) can differ from the current maximum, so we update O before accumulating with diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu index 3383ff17..6fbc7bc0 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu @@ -55,7 +55,7 @@ ``` In practice, and for numerical stability reasons, - we also substract the maximum so far (`mi`) before doing + we also subtract the maximum so far (`mi`) before doing the exponential. When we encounter new keys, the maximum used to compute O so far (`m_prime`) can differ from the current maximum, so we update O before accumulating with diff --git a/examples/41_fused_multi_head_attention/gemm/find_default_mma.h b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h index 0e38a203..560da450 100644 --- a/examples/41_fused_multi_head_attention/gemm/find_default_mma.h +++ b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h @@ -31,7 +31,7 @@ /*! \file \brief Cutlass provides helper template functions to figure out the right - datastructures to instanciate to run a GEMM with various parameters (see + datastructures to instantiate to run a GEMM with various parameters (see `cutlass/gemm/threadblock/default_mma.h`). However, due to template instantiation priority rules, it will only create an MmaMultiStage with kStages=3 (otherwise creates an MmePipelined - which is not compatible with @@ -83,7 +83,7 @@ template < typename InstructionShape, /// Number of stages used in the pipelined mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, typename Enable_ = void> struct FindDefaultMma { diff --git a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h index 94541b8d..f2b94d00 100644 --- a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h +++ b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -522,7 +522,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< // For API compatibility with MmaMultistageFromSharedMemory // but not supported as it worsens perf: older gpus < sm80 don't - // support async tranfers and have to waste registers + // support async transfers and have to waste registers CUTLASS_DEVICE void set_prologue_done(bool value) {} CUTLASS_DEVICE diff --git a/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h index dad26742..930ee46d 100644 --- a/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h +++ b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! \file - \brief Instanciates the right WarpIterator to read from shared memory + \brief Instantiates the right WarpIterator to read from shared memory The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading data dumped with `B2bGemm::accumToSmem`. */ diff --git a/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h index 1a3a9c7e..fa40d850 100644 --- a/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h +++ b/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h @@ -86,7 +86,7 @@ namespace threadblock { /// To be efficient, this assumes the iterator will be dereferenced and advanced /// at least once outside any looping structure to minimize integer arithmetic. /// -/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to /// dereferencing the iterator. /// /// diff --git a/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu b/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu index 18efab14..7627f737 100644 --- a/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu +++ b/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu @@ -49,7 +49,7 @@ Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format for this example: a_rows - Rows in the sparse matrix. - a_cols - Colums in the sparse matrix. + a_cols - Columns in the sparse matrix. a_ell_blocksize - Size of the ELL-Blocks. a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py index e10b89ad..6cd01ef1 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py @@ -153,7 +153,7 @@ class gen_device: warp_M_tile = 32 - # Determine maxmimum N_tile + # Determine maximum N_tile Max_Ntile = 0 for layer in self.fuse_gemm_info: n_tile = layer['mnk'][1] diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py index 64b0205b..5bd46516 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py @@ -76,9 +76,9 @@ class gen_verify: ) - def get_params(self, declartion = True): + def get_params(self, declaration = True): code = "" - if declartion: + if declaration: for param in self.params: code += param[0] + " " + param[1] + ";\n" diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py index 5b56f70f..f757d952 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py @@ -64,8 +64,8 @@ def write_2_headfile(filename, file_dir, string): with open(file_dir + filename, 'w') as f: f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string) -def var_idx(varaiable, index): - return varaiable + str(index) +def var_idx(variable, index): + return variable + str(index) def list_2_string(input_list, ): diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index ee841526..d5758aa2 100644 --- a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -78,7 +78,7 @@ a single default value. CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of - the collective is specified via the schedule tags that corresond to the underlying collective's + the collective is specified via the schedule tags that correspond to the underlying collective's dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto` are special cases of these schedules that allow the builder to also decide the dispatch policy for you, therefore letting the builder pick the collective specialization. diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu index 6e91a3ba..da6d6038 100644 --- a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu @@ -425,7 +425,7 @@ int main(int argc, char const **args) { // Pipeline Depth to be used i.e number of A, B buffers in shared memory constexpr int PipelineStages = 8; - // Let's choose a Warp-Specialized Mainloop implemention which uses TMA + // Let's choose a Warp-Specialized Mainloop implementation which uses TMA // Note : This requires / assumes the tensors to be 16B aligned using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index 49505284..5af4e67b 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -32,7 +32,7 @@ \brief Example of a Hopper gather+GEMM+scatter kernel fusion. This example fuses gather before GEMM and scatter after GEMM into the same - GEMM kernel. Gather and scatter operation is controled by an index vector + GEMM kernel. Gather and scatter operation is controlled by an index vector to select rows or columns from A, B, C or D matrices. Gather/scatter operations are always performed along a strided dimension diff --git a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu index 66ab9b24..2d2b7197 100644 --- a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu +++ b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu @@ -65,7 +65,7 @@ The approach relies on two things: - The ability of CUTLASS 3 to naturally perform general tensor contractions (GETT) owing to the flexibility of CuTe's hierarchical layouts (see example 51_hopper_gett for more details). - - The harware capabilities of Hopper TMA units that allow for loading multidimensional tensors with + - The hardware capabilities of Hopper TMA units that allow for loading multidimensional tensors with (almost) arbitrary strides, which can be used to represent a permuted view of the data. In this example we reuse the permutation classes of examples 39_gemm_permute as operation tags. diff --git a/examples/59_ampere_gather_scatter_conv/README.md b/examples/59_ampere_gather_scatter_conv/README.md index 2f3d8b83..b16ddf95 100644 --- a/examples/59_ampere_gather_scatter_conv/README.md +++ b/examples/59_ampere_gather_scatter_conv/README.md @@ -188,7 +188,7 @@ Running this example on an RTX 3080Ti prints the following performance numbers ( ``` $> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check -Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors. +Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors. Allocating tensors ... done. Initializing data ... done. diff --git a/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu b/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu index ee1d658a..fa28b640 100644 --- a/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu +++ b/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! \file - \brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propogation kernel + \brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel capable of operating on both affine and gather/scatter tensors. This example demonstartes a few super cool features of CUTLASS and CuTe. It shows off @@ -284,7 +284,7 @@ int ampere_gather_scatter_conv_fprop( int main(int argc, char const** argv) { cutlass::CommandLine cmd(argc, argv); - std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n"; + std::cout << "Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.\n\n"; if (cmd.check_cmd_line_flag("help")) { std::cout << "Options:\n" diff --git a/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu index 7763bc10..9bed32f5 100644 --- a/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu +++ b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu @@ -291,7 +291,7 @@ struct Options { // Post-process the problem sizes bin_problems(); - // Initalize alpha array + // Initialize alpha array randomize_alpha_ptr_array(cmd); } diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu index b3da5583..9e55755b 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -358,7 +358,7 @@ void initialize(const Options &options) { // Layout SFA and SFB represent logically broadcasting data in CuTe. // E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGraunularityK, K / ScaleGranularityK)) // and strides ((0, 1), (0, M / ScaleGraunuarlityM)), then each collection of ScaleGranularityM x ScaleGranularityK - // indecies in the tensor map to the same offset. + // indices in the tensor map to the same offset. layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); diff --git a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu index c6e1d753..add938a9 100644 --- a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu +++ b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu @@ -61,7 +61,7 @@ # Heuristic mode with deterministic reduction ./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic - # Stream-K mode with determinsitic reduction + # Stream-K mode with deterministic reduction ./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic # Split-K mode with a splitting factor of 2 and deterministic reduction diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu index 0632714a..f363dfa0 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -850,7 +850,7 @@ int run(Options &options, bool host_problem_shapes_available = true) } } else { - std::cout << " Verfication is turned off for this run." << std::endl; + std::cout << " Verification is turned off for this run." << std::endl; } // Run profiling loop diff --git a/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu index 67313926..91511235 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu @@ -36,7 +36,7 @@ APIs on NVIDIA Blackwell SM100 architecture. The basic computation logic of dgrad convolution kernel is, take 3D convolution as an example: - Xformed Actication (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC) + Xformed Activation (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC) where in terms of GEMM perspective, Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation diff --git a/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu index b1d7bc15..15ec02aa 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu @@ -36,7 +36,7 @@ APIs on NVIDIA Blackwell SM100 architecture. The basic computation logic of fprop convolution kernel is, take 3D convolution as an example: - Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK) + Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Activation (NZPQK) where in terms of GEMM perspective, Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation diff --git a/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu index abac47ae..e47dbece 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu @@ -36,7 +36,7 @@ APIs on NVIDIA Blackwell SM100 architecture. The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example: - Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC) + Xformed Activation (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC) where in terms of GEMM perspective, Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 405ddfd6..174e05cf 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -505,8 +505,12 @@ struct FwdRunner { Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()), select<0,3>(problem_shape), stride_LSE); + + auto [Q, K, D, HB] = problem_shape; - fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); + auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB); + + fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{}); cudaError_t result = cudaDeviceSynchronize(); if (result != cudaSuccess) { diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu index 67188b51..f877cc52 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu @@ -32,7 +32,7 @@ \brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3. This example showcases the use of CUTLASS to build backward fused - multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting + multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting the NVIDIA Blackwell architecture. Background and motivation @@ -117,6 +117,7 @@ struct Options { std::vector varlen_q; std::vector varlen_k; int d = 128; + int d_vo = 128; int iterations = 3; bool verify = false; bool verbose = false; @@ -178,6 +179,7 @@ struct Options { } cmd.get_cmd_line_argument("d", d, defaults.d); + cmd.get_cmd_line_argument("d_vo", d_vo, d); cmd.get_cmd_line_argument("h", h, -1); if (h == -1) h = 2048 / d; @@ -301,6 +303,7 @@ struct Options { << " --varlen-q=: Sets the variable Q extent per batch (colon separated)\n" << " --varlen-k=: Sets the variable K extent per batch (colon separated)\n" << " --d= Sets the D extent\n" + << " --d_vo= Sets the D_VO extent\n" << " --iterations= Benchmarking iterations\n" << " --verify Verify results\n" << " --verbose Print smem and execution time per kernel\n" @@ -387,6 +390,7 @@ struct ExampleResult { template< bool kIsVarlen, + bool kIsMla, class TileShape, class DispatchPolicy, class ActiveMask, @@ -404,8 +408,8 @@ struct BwdRunner { // Q K D (H B) using ProblemShape = std::conditional_t< kIsVarlen, - cute::tuple>, - cute::tuple> + cute::tuple>, + cute::tuple> >; using TensorStride = Stride>; // Seq D (H B) @@ -461,45 +465,45 @@ struct BwdRunner { // Methods // bool verify(const ProblemShape& problem_shape) { - auto [Q, K, D, HB] = problem_shape; + auto [Q, K, D, D_VO, HB] = problem_shape; auto [H, B] = HB; Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), - select<0,2,3>(problem_shape), + select<0,2,4>(problem_shape), stride_Q); Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), - select<1,2,3>(problem_shape), + select<1,2,4>(problem_shape), stride_K); Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), - select<1,2,3>(problem_shape), + select<1,3,4>(problem_shape), stride_V); Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), - select<0,2,3>(problem_shape), + select<0,3,4>(problem_shape), stride_O); // keep going here! (this might be better in cursor) Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), - select<0,3>(problem_shape), + select<0,4>(problem_shape), stride_LSE); Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), - select<0,2,3>(problem_shape), + select<0,2,4>(problem_shape), stride_dQ); Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), - select<1,2,3>(problem_shape), + select<1,2,4>(problem_shape), stride_dK); Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), - select<1,2,3>(problem_shape), + select<1,3,4>(problem_shape), stride_dV); Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), - select<0,2,3>(problem_shape), + select<0,3,4>(problem_shape), stride_dO); fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{}); @@ -595,14 +599,14 @@ struct BwdRunner { ProblemShape problem_shape{ {max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q}, {max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv}, - options.d, {options.h, options.b} + options.d, options.d_vo, {options.h, options.b} }; - auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1)); + auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1)); return cute::make_tuple(problem_shape, tensor_shape); } else { - ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}}; + ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}}; return cute::make_tuple(problem_shape, problem_shape); } } @@ -610,24 +614,25 @@ struct BwdRunner { /// Initialize operands to be used in the GEMM and reference GEMM ProblemShape initialize(Options const& options) { auto [problem_shape, tensor_shape] = initialize_problem_shape(options); - auto [Q, K, D, HB] = tensor_shape; + auto [Q, K, D, D_VO, HB] = tensor_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment // for varlen, Q == total_Q, K == total_K, B = 1 // but in problem_shape, they've got to be max_Q/max_K, and B = B - auto shape_QO = make_shape(Q, D, make_shape(H, B)); - auto shape_KV = make_shape(K, D, make_shape(H, B)); + auto shape_Q = make_shape(Q, D, make_shape(H, B)); + auto shape_O = make_shape(Q, D_VO, make_shape(H, B)); + auto shape_K = make_shape(K, D, make_shape(H, B)); + auto shape_V = make_shape(K, D_VO, make_shape(H, B)); auto shape_LSE = make_shape(Q, make_shape(H, B)); stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H)); stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H)); + stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H)); + stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H)); stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H)); - stride_V = stride_K; - stride_O = stride_Q; - stride_dQ = stride_Q; stride_dK = stride_K; stride_dV = stride_V; @@ -637,20 +642,20 @@ struct BwdRunner { return size(make_shape(1ull, shape)); }; - block_Q.reset(lsize(shape_QO)); - block_K.reset(lsize(shape_KV)); - block_V.reset(lsize(shape_KV)); - block_O.reset(lsize(shape_QO)); + block_Q.reset(lsize(shape_Q)); + block_K.reset(lsize(shape_K)); + block_V.reset(lsize(shape_V)); + block_O.reset(lsize(shape_O)); block_LSE.reset(lsize(shape_LSE)); - block_dQ.reset(lsize(shape_QO)); - block_dK.reset(lsize(shape_KV)); - block_dV.reset(lsize(shape_KV)); - block_dO.reset(lsize(shape_QO)); + block_dQ.reset(lsize(shape_Q)); + block_dK.reset(lsize(shape_K)); + block_dV.reset(lsize(shape_V)); + block_dO.reset(lsize(shape_O)); - block_ref_dQ.reset(lsize(shape_QO)); - block_ref_dK.reset(lsize(shape_KV)); - block_ref_dV.reset(lsize(shape_KV)); + block_ref_dQ.reset(lsize(shape_Q)); + block_ref_dK.reset(lsize(shape_K)); + block_ref_dV.reset(lsize(shape_V)); initialize_block(block_Q, seed + 2023, options.init_style_q); initialize_block(block_K, seed + 2022, options.init_style_k); @@ -665,23 +670,23 @@ struct BwdRunner { initialize_block(block_ref_dV, seed + 2035); Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), - select<0,2,3>(problem_shape), + select<0,2,4>(problem_shape), stride_Q); Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), - select<1,2,3>(problem_shape), + select<1,2,4>(problem_shape), stride_K); Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), - select<1,2,3>(problem_shape), + select<1,3,4>(problem_shape), stride_V); Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), - select<0,2,3>(problem_shape), + select<0,3,4>(problem_shape), stride_O); Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), - select<0,3>(problem_shape), + select<0,4>(problem_shape), stride_LSE); if (! options.skip_reference) { @@ -698,7 +703,7 @@ struct BwdRunner { ExampleResult example_result; - using Operation = cutlass::fmha::device::Sm100FmhaBwd; + using Operation = cutlass::fmha::device::Sm100FmhaBwd; typename Operation::Arguments arguments{ problem_shape, @@ -811,12 +816,12 @@ struct BwdRunner { runtime_ms /= static_cast(options.iterations); - double flops = 10.0 * (std::is_same_v ? 0.5 : 1.0); + double flops = 2.0 * (std::is_same_v ? 0.5 : 1.0); flops *= static_cast(get<0>(problem_shape)); flops *= static_cast(get<1>(problem_shape)); - flops *= static_cast(get<2>(problem_shape)); - flops *= static_cast(get<3,0>(problem_shape)); - flops *= static_cast(get<3,1>(problem_shape)); + flops *= (3 * static_cast(get<2>(problem_shape)) + 2 * static_cast(get<3>(problem_shape))); + flops *= static_cast(get<4,0>(problem_shape)); + flops *= static_cast(get<4,1>(problem_shape)); double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); example_result.tflops_tc_s = tflops_s; example_result.runtime_ms = runtime_ms; @@ -892,7 +897,7 @@ template void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { dispatch_bool(options.varlen, [&](auto is_varlen) { - BwdRunner runner; + BwdRunner runner; auto result = runner.run(options, hw_info); print_result(name, result, options.verbose); }); @@ -900,7 +905,7 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf using HeadDim = _64; - run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma"); + run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma"); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -909,7 +914,7 @@ template void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { dispatch_bool(options.varlen, [&](auto is_varlen) { - BwdRunner runner; + BwdRunner runner; auto result = runner.run(options, hw_info); print_result(name, result, options.verbose); }); @@ -917,7 +922,22 @@ void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn using HeadDim = _128; - run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma"); + run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma"); +} + +template +void run_bwd_mla_192(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + dispatch_bool(options.varlen, [&](auto is_varlen) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }); + }; + + using HeadDim = _192; + + run(Shape<_64, _128, HeadDim, _128>{}, KernelCoop{}, "tma"); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -981,7 +1001,7 @@ int main_single(int argc, char const **args) { hw_info.sm_count = options.sm_count; } - std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " "; + std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " "; std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " "; std::cout << "#SM " << hw_info.sm_count << std::endl; @@ -998,12 +1018,15 @@ int main_single(int argc, char const **args) { }; with_causal([&](auto fusion) { - if (options.d <= 64) { + if (options.d <= 64 && options.d_vo == options.d) { run_bwd_64(fusion, options, hw_info); } - else if (options.d <= 128) { + else if (options.d <= 128 && options.d_vo == options.d) { run_bwd_128(fusion, options, hw_info); } + else if (options.d == 192 && options.d_vo == 128) { + run_bwd_mla_192(fusion, options, hw_info); + } else { std::cout << "No kernel instantiated for d=" << options.d << std::endl; } diff --git a/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu b/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu index 51420b00..df229ffb 100644 --- a/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu +++ b/examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu @@ -485,7 +485,11 @@ struct MlaFwdRunner { select<0,3>(problem_shape), stride_LSE); - fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); + auto [Q, K, D, HB] = problem_shape; + + auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB); + + fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{}); cudaError_t result = cudaDeviceSynchronize(); if (result != cudaSuccess) { diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index e5c998e7..2b8f1da2 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -84,6 +84,8 @@ set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap) set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only) set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify) +set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no) +set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen) if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a)) @@ -174,6 +176,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_VARLEN_12 TEST_VARLEN_13 TEST_VARLEN_14 + TEST_BWD_MLA_BASIC + TEST_BWD_MLA_VARLEN ) target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md index 58ad99a8..1e28929b 100644 --- a/examples/77_blackwell_fmha/README.md +++ b/examples/77_blackwell_fmha/README.md @@ -37,13 +37,19 @@ There are three kernels to compute backwards: `Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel. +## MLA Blackwell Backward + +The sample also provides the feature of MLA backward(d=192, d_vo=128). To enable MLA backward, please specify `--d=192 --d_vo=128` when running the bwd sample. + +`Sm100FmhaBwdMlaKernelTmaWarpSpecialized`is the main point for MLA backward. The MLA approach is slightly different from the original one to enable high performance with the MLA shape. + # MLA Inference for Blackwell This sample provides code for fused multi-head latent attention inference in the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64. It supports fp16, bf16, and fp8 input and output types. -To accomodate the large output accumulator due to the large latent head dimension, +To accommodate the large output accumulator due to the large latent head dimension, the sample demonstrates how to leverage 2Sm Blackwell tensor cores. Loading can be done via TMA (either without paging or with page size 128), or using `cp.async` diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp index 3c8f7195..5c5de849 100644 --- a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -39,6 +39,7 @@ #include "../device/fmha.hpp" #include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp" #include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" #include "../kernel/fmha_kernel_bwd_convert.hpp" @@ -55,13 +56,14 @@ template< class Element, class ElementAccumulator, class TileShape, + bool IsMla, class Mask > class Sm100FmhaBwd { public: /// Argument structure: User API struct Arguments { - // Q K D HB + // Q K D D_VO HB ProblemShape problem_shape; const Element* ptr_Q; @@ -98,11 +100,20 @@ public: cutlass::fmha::kernel::FmhaKernelBwdConvert >; - using Operation = cutlass::fmha::device::FMHA< + using OperationNormal= cutlass::fmha::device::FMHA< cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< ProblemShape, Element, ElementAccumulator, TileShape, Mask > >; + + using OperationMla = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using Operation = std::conditional_t; + using Kernel = typename Operation::Kernel; struct Params { @@ -121,7 +132,7 @@ private: ElementAccumulator* sum_odo = nullptr, ElementAccumulator* scaled_lse = nullptr) { using namespace cute; - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment @@ -141,7 +152,7 @@ private: static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { using namespace cute; - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment @@ -163,6 +174,7 @@ private: ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { + return typename Operation::Arguments{ args.problem_shape, { args.ptr_Q, args.stride_Q, @@ -207,7 +219,7 @@ public: /// Gets the workspace size static size_t get_workspace_size(Arguments const& args) { - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment @@ -227,7 +239,7 @@ public: CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment @@ -256,7 +268,7 @@ public: CUTLASS_TRACE_HOST("Universal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp index c7f869f9..f18fa3bf 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert { static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; static bool can_implement(Arguments const& args) { - return get<2>(args.problem_shape) % kElementsPerLoad == 0; + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { - dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); + dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); return grid; } @@ -103,7 +103,7 @@ struct FmhaKernelBwdConvert { } template - CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) { + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; @@ -120,7 +120,7 @@ struct FmhaKernelBwdConvert { auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); - for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) { ElementAcc value_src[kElementsPerLoad]; Element value_dest[kElementsPerLoad]; @@ -139,13 +139,13 @@ struct FmhaKernelBwdConvert { CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { if (params.ptr_src_dQ != nullptr) { - copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape)); + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); } if (params.ptr_src_dK != nullptr) { - copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape)); + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape)); } if (params.ptr_src_dV != nullptr) { - copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape)); + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); } } }; diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp index 98c127da..4a26d768 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO { static const int kIterationsQ = kBlockQ / kNumThreadsQ; static bool can_implement(Arguments const& args) { - return get<2>(args.problem_shape) % kElementsPerLoad == 0; + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { - dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape)); + dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape)); return grid; } @@ -131,7 +131,7 @@ struct FmhaKernelBwdSumOdO { auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); - for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { Element value_O[kElementsPerLoad]; Element value_dO[kElementsPerLoad]; diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 82ae4270..48f502da 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -344,12 +344,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static bool can_implement(Arguments const& args) { - auto [Q, K, D, HB] = args.problem_shape; + auto [Q, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; - if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0) { + if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) { return false; } - if (D % Alignment != 0) { + if (D % Alignment != 0 || D_VO % Alignment != 0) { return false; } return true; @@ -362,7 +362,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static Params to_underlying_arguments(Arguments const& args, void*) { - auto [Q_, K_, D, HB] = args.problem_shape; + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; int Q = Q_; int K = K_; @@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { }, /*workspace=*/nullptr); auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( - make_shape(K, Q, D, HB), + make_shape(K, Q, D_VO, HB), typename CollectiveMmaVDO::Arguments { args.mainloop.ptr_v, args.mainloop.stride_v, args.mainloop.ptr_do, args.mainloop.stride_do, @@ -446,21 +446,21 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { - auto [Q, K, D, HB] = problem_shape; + auto [Q, K, D, D_VO, HB] = problem_shape; using X = Underscore; uint16_t mcast_mask = 0; auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); - auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); - auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); - auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in); - auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in); - auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in); - auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in); + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); @@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { // set up lse and sum_odo - auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); @@ -681,7 +681,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { - auto [Q, K, D, HB] = problem_shape; + auto [Q, K, D, D_VO, HB] = problem_shape; auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); @@ -974,11 +974,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args) { - auto [Q, K, D, HB] = problem_shape; - auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); - auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); @@ -988,7 +988,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ); auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); - auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); @@ -1003,7 +1003,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } } for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { - if (elem_less(cDV(i), select<1,2>(problem_shape))) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { gDV(i) = Element(0); } } @@ -1020,8 +1020,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { - auto [Q, K, D, HB] = problem_shape; - auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; @@ -1029,7 +1029,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tDKtDK.data() = TmemAllocation::kDK; auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); - auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); @@ -1065,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tDVtDV.data() = TmemAllocation::kDV; auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); - auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); @@ -1088,7 +1088,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); // store tDVgDV - store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,2>(problem_shape)); + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); @@ -1140,7 +1140,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { - auto [Q, K, D, HB] = problem_shape; + auto [Q, K, D, D_VO, HB] = problem_shape; // in tmem, S & P overlap // and dP and dQ overlap @@ -1396,9 +1396,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { using X = Underscore; - auto [Q, K, D, HB] = problem_shape; + auto [Q, K, D, D_VO, HB] = problem_shape; - auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; // must match TileShapeDQ auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; @@ -1676,7 +1676,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { pipeline_init_wait(size(ClusterShape{})); - auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); auto [problem_shape, blk_offset] = apply_variable_length_offset( params.problem_shape, blk_coord @@ -1809,7 +1809,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } static dim3 get_grid_shape(Params const& params) { - auto [Q, K, D, HB] = params.problem_shape; + auto [Q, K, D, D_VO, HB] = params.problem_shape; auto [H, B] = HB; dim3 grid(ceil_div(K, TileShapeK{}), H, B); return grid; diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp new file mode 100644 index 00000000..67e7203c --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1813 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + using TileShapeK = decltype(get<1>(TileShape{})); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<3>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + 65536 * 16; + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kDQ + TileShapeDQK{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + + static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp; + static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp"); + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + + // compute dP + using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDOV = typename CollectiveMmaDOV::TileShape; + using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma; + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{})); + using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + union{ + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array> smem_p_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaDOV::Params::TMA_B; + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaQK::to_underlying_arguments( + make_shape(Q, K, D, HB), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q, args.mainloop.stride_q, + args.mainloop.ptr_k, args.mainloop.stride_k, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaDOV::to_underlying_arguments( + make_shape(Q, K, D_VO, HB), + typename CollectiveMmaDOV::Arguments { + args.mainloop.ptr_do, args.mainloop.stride_do, + args.mainloop.ptr_v, args.mainloop.stride_v, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_b, + params_vdo.tma_load_b, + params_kq.tma_load_a, + params_vdo.tma_load_a, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step{}); + auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_B(gK); + auto tSTgQ = cta_mma_kq.partition_A(gQ); + auto tDPTgV = cta_mma_vdo.partition_B(gV); + auto tDPTgDO = cta_mma_vdo.partition_A(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading kLoadPerThread * 32 values of 32b each + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaQK::make_fragment_B(sK); + Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ); + + Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV); + Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP); + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaQK tiled_mma_qk; + TiledMmaDOV tiled_mma_dov; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + Tensor tCr = thr_copy.partition_S(quantized_regs); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tPc = thr_copy.partition_D(preds); + + copy_if(copy_op, tPc, tCr, tCg); + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_16dp32b32x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{})); + Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + } + }; + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cPT_p = thread_t2r.partition_D(cPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask> + || std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) + (_, _, _, pipeline_compute_mma_p_producer_state.index()); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + auto sP_pi = as_position_independent_swizzle_tensor(sP); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p))); + auto sP_pi_slice = split_wg(sP_pi_slice_p); + copy_aligned(tRT_rST, sP_pi_slice); + }); + + // notify for P + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_16dp32b16x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask> + || std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp index 66883af4..6f4a0c7a 100644 --- a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -33,7 +33,9 @@ #pragma once #include "cute/tensor.hpp" +#include "collective/fmha_fusion.hpp" +using namespace cutlass::fmha::collective; ///////////////////////////////////////////////////////////////////////////////////////////////// template< @@ -61,20 +63,20 @@ void __global__ fmha_bwd_reference_dQ_kernel( ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in))); - for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { auto [problem_shape, offset] = apply_variable_length_offset( problem_shape_in, - make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) + make_coord(_0{}, _0{}, _0{}, _0{},idx2crd(idx_L, get<4>(problem_shape_in))) ); // problem_shape = problem_shape_in; // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); - auto mK = domain_offset(select<1,2,3>(offset), mK_in); - auto mV = domain_offset(select<1,2,3>(offset), mV_in); - auto mO = domain_offset(select<0,2,3>(offset), mO_in); - auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); - auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); - auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in); + auto mQ = domain_offset(select<0,2,4>(offset), mQ_in); + auto mK = domain_offset(select<1,2,4>(offset), mK_in); + auto mV = domain_offset(select<1,3,4>(offset), mV_in); + auto mO = domain_offset(select<0,3,4>(offset), mO_in); + auto mLSE = domain_offset(select<0,4>(offset), mLSE_in); + auto mDO = domain_offset(select<0,3,4>(offset), mDO_in); + auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in); for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) { for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { ElementAccumulator acc_qk = 0; @@ -82,10 +84,15 @@ void __global__ fmha_bwd_reference_dQ_kernel( ElementAccumulator acc_doo = 0; for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); - acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); - acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + // acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + // acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); } // for idx_D0 + for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) { + acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L); + acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); + } + auto id = make_identity_tensor(make_shape(1, 1)); auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc_qk; @@ -135,20 +142,20 @@ void __global__ fmha_bwd_reference_dK_kernel( ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in))); - for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { auto [problem_shape, offset] = apply_variable_length_offset( problem_shape_in, - make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) + make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in))) ); // problem_shape = problem_shape_in; // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); - auto mK = domain_offset(select<1,2,3>(offset), mK_in); - auto mV = domain_offset(select<1,2,3>(offset), mV_in); - auto mO = domain_offset(select<0,2,3>(offset), mO_in); - auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); - auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); - auto mDK = domain_offset(select<1,2,3>(offset), mDK_in); + auto mQ = domain_offset(select<0,2,4>(offset), mQ_in); + auto mK = domain_offset(select<1,2,4>(offset), mK_in); + auto mV = domain_offset(select<1,3,4>(offset), mV_in); + auto mO = domain_offset(select<0,3,4>(offset), mO_in); + auto mLSE = domain_offset(select<0,4>(offset), mLSE_in); + auto mDO = domain_offset(select<0,3,4>(offset), mDO_in); + auto mDK = domain_offset(select<1,2,4>(offset), mDK_in); for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) { for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { ElementAccumulator acc_qk = 0; @@ -156,10 +163,14 @@ void __global__ fmha_bwd_reference_dK_kernel( ElementAccumulator acc_doo = 0; for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); - acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); - acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + // acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + // acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); } // for idx_D0 - + + for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) { + acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L); + acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); + } auto id = make_identity_tensor(make_shape(1, 1)); auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc_qk; @@ -209,20 +220,20 @@ void __global__ fmha_bwd_reference_dV_kernel( ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in))); - for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { auto [problem_shape, offset] = apply_variable_length_offset( problem_shape_in, - make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) + make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in))) ); // problem_shape = problem_shape_in; // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); - auto mK = domain_offset(select<1,2,3>(offset), mK_in); - auto mV = domain_offset(select<1,2,3>(offset), mV_in); - auto mO = domain_offset(select<0,2,3>(offset), mO_in); - auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); - auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); - auto mDV = domain_offset(select<1,2,3>(offset), mDV_in); + auto mQ = domain_offset(select<0,2,4>(offset), mQ_in); + auto mK = domain_offset(select<1,2,4>(offset), mK_in); + auto mV = domain_offset(select<1,3,4>(offset), mV_in); + auto mO = domain_offset(select<0,3,4>(offset), mO_in); + auto mLSE = domain_offset(select<0,4>(offset), mLSE_in); + auto mDO = domain_offset(select<0,3,4>(offset), mDO_in); + auto mDV = domain_offset(select<1,3,4>(offset), mDV_in); for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) { for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { ElementAcc acc_qk = 0; @@ -244,7 +255,7 @@ void __global__ fmha_bwd_reference_dV_kernel( __syncthreads(); - for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { + for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) { ElementAcc acc = 0; for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) { ElementAcc rS = static_cast(mS[idx_Q]); diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp index bcd482f9..d674ee95 100644 --- a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -62,19 +62,20 @@ void __global__ fmha_reference_kernel( ElementAccumulator softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mQ))); auto id = make_identity_tensor(make_shape(1, 1)); - for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) { - auto coord_L = idx2crd(idx_L, shape<3>(problem_shape_in)); + auto coord_L = idx2crd(idx_L, shape<4>(problem_shape_in)); auto get_coord_in = [&]() { if constexpr (rank_v(ProblemShapeIn{}))> == 2) { - return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), coord_L); + return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), cute::make_tuple(_0{}, _0{}), coord_L); } else { - return cute::make_tuple(idx_Q, _0{}, _0{}, coord_L); + return cute::make_tuple(idx_Q, _0{}, _0{}, _0{}, coord_L); } }; auto coord_in = get_coord_in(); - auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<3,1>(coord_in)); + auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<4,1>(coord_in)); int head_qk = 0; int head_v = 0; @@ -83,7 +84,7 @@ void __global__ fmha_reference_kernel( head_qk = size<2, 0>(problem_shape) + size<2, 1>(problem_shape); head_v = size<2, 0>(problem_shape); } else { - head_qk = size<2>(problem_shape); + head_qk = size<3>(problem_shape); head_v = head_qk; } @@ -157,6 +158,7 @@ void __global__ fmha_reference_kernel( mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast(acc * scale); } + if (threadIdx.x == 0 && mLSE.data() != nullptr) { mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; } diff --git a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu index 48df9108..3342eebb 100644 --- a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu @@ -835,7 +835,7 @@ int run(Options &options, bool host_problem_shapes_available = true) } } else { - std::cout << " Verfication is turned off for this run." << std::endl; + std::cout << " Verification is turned off for this run." << std::endl; } // Run profiling loop diff --git a/examples/cute/tutorial/blackwell/01_mma_sm100.cu b/examples/cute/tutorial/blackwell/01_mma_sm100.cu index a11fb17c..d2d5a068 100644 --- a/examples/cute/tutorial/blackwell/01_mma_sm100.cu +++ b/examples/cute/tutorial/blackwell/01_mma_sm100.cu @@ -259,7 +259,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Step 2: The Mainloop. - // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + // Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator. tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM @@ -394,7 +394,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. - // The MMA's partitioning then yeilds the CTA-local work. + // The MMA's partitioning then yields the CTA-local work. if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; diff --git a/examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu b/examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu index 4ce2f4a8..992233fd 100644 --- a/examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu +++ b/examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu @@ -295,7 +295,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Step 2: The Mainloop. - // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + // Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator. tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM @@ -433,7 +433,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. - // The MMA's partitioning then yeilds the CTA-local work. + // The MMA's partitioning then yields the CTA-local work. if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; diff --git a/examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu b/examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu index bc788bad..1626238c 100644 --- a/examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu +++ b/examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu @@ -333,7 +333,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Step 2: The Mainloop. - // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + // Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator. tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM @@ -471,7 +471,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. - // The MMA's partitioning then yeilds the CTA-local work. + // The MMA's partitioning then yields the CTA-local work. if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; diff --git a/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu b/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu index 9b17cd59..1fdf1edd 100644 --- a/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu +++ b/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu @@ -328,7 +328,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Step 2: The Mainloop. - // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + // Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator. tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM @@ -473,7 +473,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. - // The MMA's partitioning then yeilds the CTA-local work. + // The MMA's partitioning then yields the CTA-local work. if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; diff --git a/examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu b/examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu index 44b97587..82649eb5 100644 --- a/examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu +++ b/examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu @@ -341,7 +341,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Step 2: The Mainloop. - // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + // Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator. tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM @@ -527,7 +527,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. - // The MMA's partitioning then yeilds the CTA-local work. + // The MMA's partitioning then yields the CTA-local work. if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; diff --git a/examples/cute/tutorial/tiled_copy.cu b/examples/cute/tutorial/tiled_copy.cu index 3cdc2784..e82e4f27 100644 --- a/examples/cute/tutorial/tiled_copy.cu +++ b/examples/cute/tutorial/tiled_copy.cu @@ -200,7 +200,7 @@ int main(int argc, char** argv) // Construct tiled copy, a tiling of copy atoms. // - // Note, this assumes the vector and thread layouts are aligned with contigous data + // Note, this assumes the vector and thread layouts are aligned with contiguous data // in GMEM. Alternative thread layouts are possible but may result in uncoalesced // reads. Alternative value layouts are also possible, though incompatible layouts // will result in compile time errors. diff --git a/examples/python/CuTeDSL/ampere/elementwise_add.py b/examples/python/CuTeDSL/ampere/elementwise_add.py index 6b244b01..596822cc 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_add.py +++ b/examples/python/CuTeDSL/ampere/elementwise_add.py @@ -90,18 +90,17 @@ If you already know the TV layout you want to use for your tiled copy, CuTe DSL # Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN)) gA = cute.zipped_divide(mA, tiler_mn) -where `tiler_mn` is the tile size per thread block and `tv_layout` is the TV layout which maps -thread index and inter-thread index of data array per thread to logical coordinates of elements in -input and output tensors. - -Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy` utility. +Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy_tv` utility, which +infers the tiler and tv layout for the tiled copy automatically, where `tiler` is the tile size per thread +block and `tv_layout` is the TV layout which maps thread index and inter-thread index of data array per +thread to logical coordinates of elements in input and output tensors. .. code-block:: python blkA = gA[((None, None), bidx)] # (TileM,TileN) copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type) - tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) + tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) # get slice of tiled_copy_A for current thread thr_copy_A = tiled_copy_A.get_slice(tidx) @@ -140,8 +139,8 @@ def elementwise_add_kernel( gC: cute.Tensor, cC: cute.Tensor, # coordinate tensor shape: cute.Shape, - tv_layout: cute.Layout, - tiler_mn: cute.Shape, + thr_layout: cute.Layout, + val_layout: cute.Layout, ): tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() @@ -165,9 +164,9 @@ def elementwise_add_kernel( copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type) copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type) - tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) - tiled_copy_B = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) - tiled_copy_C = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn) + tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout) thr_copy_A = tiled_copy_A.get_slice(tidx) thr_copy_B = tiled_copy_B.get_slice(tidx) @@ -254,7 +253,7 @@ def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128): cC = cute.zipped_divide(idC, tiler=tiler_mn) print(f"[DSL INFO] coord tensor = {cC.type}") - elementwise_add_kernel(gA, gB, gC, cC, mC.shape, tv_layout, tiler_mn).launch( + elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch( grid=[cute.size(gC, mode=[1]), 1, 1], block=[cute.size(tv_layout, mode=[0]), 1, 1], ) @@ -362,7 +361,7 @@ def run_elementwise_add( workspace_generator=generate_tensors, workspace_count=10, warmup_iterations=warmup_iterations, - profiling_iterations=iterations, + iterations=iterations, ) # Print execution results diff --git a/examples/python/CuTeDSL/ampere/elementwise_apply.py b/examples/python/CuTeDSL/ampere/elementwise_apply.py index 649c7789..43c3b5fc 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_apply.py +++ b/examples/python/CuTeDSL/ampere/elementwise_apply.py @@ -353,7 +353,7 @@ def run_elementwise_apply_and_verify( current_stream, ), warmup_iterations=warmup_iterations, - profiling_iterations=iterations, + iterations=iterations, use_cuda_graphs=True, stream=current_stream, ) diff --git a/examples/python/CuTeDSL/ampere/flash_attention_v2.py b/examples/python/CuTeDSL/ampere/flash_attention_v2.py index b36ec1d1..441238cc 100644 --- a/examples/python/CuTeDSL/ampere/flash_attention_v2.py +++ b/examples/python/CuTeDSL/ampere/flash_attention_v2.py @@ -32,13 +32,13 @@ from typing import Type, Union, Callable import torch import cuda.bindings.driver as cuda - +import cutlass.cute.testing as testing import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warp import cutlass.torch as cutlass_torch from cutlass.cute.runtime import from_dlpack -import cutlass.utils.ampere_helpers as sm80_utils +import cutlass.utils as utils """ A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL. @@ -163,7 +163,7 @@ class FlashAttentionForwardAmpere: # Check if block size setting is out of shared memory capacity # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2 - smem_capacity = sm80_utils.SMEM_CAPACITY["sm80"] + smem_capacity = utils.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False @@ -469,21 +469,9 @@ class FlashAttentionForwardAmpere: warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self._dtype, ) - smem_tiled_copy_Q = cute.make_tiled_copy( - smem_copy_atom_Q, - layout_tv=tiled_mma.tv_layout_A_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), - ) - smem_tiled_copy_K = cute.make_tiled_copy( - smem_copy_atom_K, - layout_tv=tiled_mma.tv_layout_B_tiled, - tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - ) - smem_tiled_copy_V = cute.make_tiled_copy( - smem_copy_atom_V, - layout_tv=tiled_mma.tv_layout_B_tiled, - tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - ) + smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma) + smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma) + smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma) smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx) smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx) @@ -702,11 +690,7 @@ class FlashAttentionForwardAmpere: cute.nvgpu.CopyUniversalOp(), self._dtype ) # tiled copy atom for O - smem_tiled_copy_O = cute.make_tiled_copy( - smem_copy_atom_O, - layout_tv=tiled_mma.tv_layout_C_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), - ) + smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma) smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) taccOsO = smem_thr_copy_O.partition_D(sO) @@ -1178,7 +1162,7 @@ class FlashAttentionForwardAmpere: return cute.arch.exp2(x) -def run_flash_attention_fwd( +def run( dtype: Type[cutlass.Numeric], batch_size: int, seqlen_q: int, @@ -1193,6 +1177,8 @@ def run_flash_attention_fwd( warmup_iterations: int = 0, iterations: int = 1, skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, ): # Skip unsupported testcase if not FlashAttentionForwardAmpere.can_implement( @@ -1207,6 +1193,23 @@ def run_flash_attention_fwd( f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}" ) + print(f"Running Ampere SM80 FlashAttentionForward test with:") + print(f" dtype: {dtype}") + print(f" batch_size: {batch_size}") + print(f" seqlen_q: {seqlen_q}") + print(f" seqlen_k: {seqlen_k}") + print(f" num_head: {num_head}") + print(f" head_dim: {head_dim}") + print(f" softmax_scale: {softmax_scale}") + print(f" m_block_size: {m_block_size}") + print(f" n_block_size: {n_block_size}") + print(f" num_threads: {num_threads}") + print(f" is_causal: {is_causal}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + # Create tensor Q/K/V/O def create_tensor( batch_size: int, @@ -1217,22 +1220,28 @@ def run_flash_attention_fwd( ) -> cute.Tensor: # (batch_size, seqlen, num_head, head_dim) shape = (batch_size, seqlen, num_head, head_dim) - return ( - torch.empty(*shape, dtype=torch.int32).random_(-2, 2).to(dtype=dtype).cuda() + torch_tensor = ( + torch.empty(*shape, dtype=torch.int32) + .random_(-2, 2) + .to(dtype=cutlass_torch.dtype(dtype)) + .cuda() ) + # assume input is 16B aligned. + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=torch_tensor.dim_order(), + divisibility=(128 // dtype.width), + ) + ) + return cute_tensor, torch_tensor - q = create_tensor( - batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype) - ) - k = create_tensor( - batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype) - ) - v = create_tensor( - batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype) - ) - o = create_tensor( - batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype) - ) + q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) fa2_fwd = FlashAttentionForwardAmpere( head_dim, @@ -1241,78 +1250,63 @@ def run_flash_attention_fwd( num_threads, is_causal, ) - # assume input is 16B align. - q_tensor = ( - from_dlpack(q, assumed_align=16) - .mark_layout_dynamic(leading_dim=3) - .mark_compact_shape_dynamic( - mode=3, stride_order=q.dim_order(), divisibility=(128 // dtype.width) - ) - ) - k_tensor = ( - from_dlpack(k, assumed_align=16) - .mark_layout_dynamic(leading_dim=3) - .mark_compact_shape_dynamic( - mode=3, stride_order=k.dim_order(), divisibility=(128 // dtype.width) - ) - ) - v_tensor = ( - from_dlpack(v, assumed_align=16) - .mark_layout_dynamic(leading_dim=3) - .mark_compact_shape_dynamic( - mode=3, stride_order=v.dim_order(), divisibility=(128 // dtype.width) - ) - ) - o_tensor = ( - from_dlpack(o, assumed_align=16) - .mark_layout_dynamic(leading_dim=3) - .mark_compact_shape_dynamic( - mode=3, stride_order=o.dim_order(), divisibility=(128 // dtype.width) - ) - ) + # Get current CUDA stream from PyTorch torch_stream = torch.cuda.current_stream() # Get the raw stream pointer as a CUstream current_stream = cuda.CUstream(torch_stream.cuda_stream) # compile the fa2 forward pass - compiled_fa2_fwd = cute.compile( - fa2_fwd, q_tensor, k_tensor, v_tensor, o_tensor, softmax_scale, current_stream + compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream) + + if not skip_ref_check: + compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream) + torch.cuda.synchronize() + q_ref = q_torch.permute(0, 2, 1, 3) + k_ref = k_torch.permute(0, 2, 1, 3) + v_ref = v_torch.permute(0, 2, 1, 3) + torch.backends.cuda.enable_flash_sdp(enabled=True) + ref_o = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal + ).permute(0, 2, 1, 3) + torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04) + print("Results verified successfully!") + + def generate_tensors(): + q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + return testing.JitArguments( + q_workspace, + k_workspace, + v_workspace, + o_workspace, + softmax_scale, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_torch.numel() * q_torch.element_size() + + k_torch.numel() * k_torch.element_size() + + v_torch.numel() * v_torch.element_size() + + o_torch.numel() * o_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_fa2_fwd, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, ) - # warmup - for _ in range(warmup_iterations): - compiled_fa2_fwd( - q_tensor, - k_tensor, - v_tensor, - o_tensor, - softmax_scale, - current_stream, - ) - # run the compiled fa2 forward pass - for _ in range(iterations): - compiled_fa2_fwd( - q_tensor, - k_tensor, - v_tensor, - o_tensor, - softmax_scale, - current_stream, - ) - torch.cuda.synchronize() - - if skip_ref_check: - return - # reference implementation - q_ref = q.permute(0, 2, 1, 3) - k_ref = k.permute(0, 2, 1, 3) - v_ref = v.permute(0, 2, 1, 3) - torch.backends.cuda.enable_flash_sdp(enabled=True) - ref_o = torch.nn.functional.scaled_dot_product_attention( - q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal - ).permute(0, 2, 1, 3) - - torch.testing.assert_close(o.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04) + return avg_time_us # Return execution time in microseconds if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -1334,9 +1328,15 @@ if __name__ == "__main__": parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference check" ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() - run_flash_attention_fwd( + run( args.dtype, args.batch_size, args.seqlen_q, @@ -1348,6 +1348,10 @@ if __name__ == "__main__": args.n_block_size, args.num_threads, args.is_causal, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/ampere/sgemm.py b/examples/python/CuTeDSL/ampere/sgemm.py index 474366e5..8058f24d 100644 --- a/examples/python/CuTeDSL/ampere/sgemm.py +++ b/examples/python/CuTeDSL/ampere/sgemm.py @@ -634,16 +634,50 @@ class SGemm: return -def main( +def run( + mnk: Tuple[int, int, int], a_major: str, b_major: str, c_major: str, - problem_shape: Tuple[int, int, int], + static_shape: bool = False, warmup_iterations: int = 2, iterations: int = 100, skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, ): - M, N, K = problem_shape + """Execute SIMT GEMM operation and benchmark performance. + + :param mnk: GEMM problem size (M, N, K, L) + :type mnk: Tuple[int, int, int, int] + :param a_major: Memory layout of tensor A + :type a_major: str + :param b_major: Memory layout of tensor B + :type b_major: str + :param c_major: Memory layout of tensor C + :type c_major: str + :param static_shape: Whether to use static shape optimization, defaults to False + :type static_shape: bool, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 2 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 100 + :type iterations: int, optional + :param skip_ref_check: Skip validation against reference implementation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print(f"Running Ampere SIMT GEMM example:") + print(f"mnk: {mnk}") + print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}") + print(f"Static shape: {static_shape}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + M, N, K = mnk # Create and permute tensor A/B/C def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype): @@ -710,20 +744,6 @@ def main( print("Executing GEMM kernel...") - avg_time_us = testing.benchmark( - gemm, - kernel_arguments=testing.JitArguments( - a_tensor, b_tensor, c_tensor, current_stream - ), - warmup_iterations=warmup_iterations, - profiling_iterations=iterations, - use_cuda_graphs=False, - stream=current_stream, - ) - - # Print execution results - print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") - if not skip_ref_check: gemm(a_tensor, b_tensor, c_tensor) torch.cuda.synchronize() @@ -732,6 +752,71 @@ def main( torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) print("Results verified successfully!") + def generate_tensors(): + # Create new tensors for each workspace to ensure cold L2 cache + a_workspace = create_and_permute_tensor(M, K, a_major == "m", torch.float32) + b_workspace = create_and_permute_tensor(N, K, b_major == "n", torch.float32) + c_workspace = create_and_permute_tensor(M, N, c_major == "m", torch.float32) + + if static_shape: + a_tensor_workspace = ( + from_dlpack(a_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if a_major == "k" else 0), + divisibility=divisibility_a, + ) + ) + else: + a_tensor_workspace = from_dlpack(a_workspace, assumed_align=16) + + b_tensor_workspace = ( + from_dlpack(b_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if b_major == "k" else 0), + divisibility=divisibility_b, + ) + ) + + c_tensor_workspace = ( + from_dlpack(c_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if c_major == "n" else 0), + divisibility=divisibility_c, + ) + ) + + return testing.JitArguments( + a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, current_stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a.numel() * a.element_size() + + b.numel() * b.element_size() + + c.numel() * c.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + # Print execution results + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") + + return avg_time_us # Return execution time in microseconds + if __name__ == "__main__": @@ -753,19 +838,27 @@ if __name__ == "__main__": parser.add_argument("--warmup_iterations", default=2, type=int) parser.add_argument("--iterations", default=100, type=int) parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() print("Running SIMT GEMM example:") torch.manual_seed(1024) - main( + run( + args.mnk, args.a_major, args.b_major, args.c_major, - args.mnk, + args.static_shape, args.warmup_iterations, args.iterations, args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/examples/python/CuTeDSL/ampere/tensorop_gemm.py index 2e6482f9..86110f36 100644 --- a/examples/python/CuTeDSL/ampere/tensorop_gemm.py +++ b/examples/python/CuTeDSL/ampere/tensorop_gemm.py @@ -51,7 +51,7 @@ This GEMM kernel supports the following features: - Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations - Threadblock rasterization to improve data re-use - Supports multi-stage pipeline to overlap computation and memory access - - Implements shared memory buffering for epilogue to increase coalesed global memory access + - Implements shared memory buffering for epilogue to increase coalesced global memory access This GEMM works as follows: 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies. @@ -214,7 +214,7 @@ class TensorOpGemm: atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits ) - # Creates a synchonous copy atom and thread layouts for the epilogue + # Creates a synchronous copy atom and thread layouts for the epilogue c_copy_bits = 128 atom_sync_copy = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), @@ -550,16 +550,8 @@ class TensorOpGemm: # Creates the tiled copy so that it matches the thread-value layout # expected by the tiled mma - tiled_copy_s2r_A = cute.make_tiled_copy( - atom_copy_s2r_A, - layout_tv=tiled_mma.tv_layout_A_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), - ) - tiled_copy_s2r_B = cute.make_tiled_copy( - atom_copy_s2r_B, - layout_tv=tiled_mma.tv_layout_B_tiled, - tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - ) + tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma) + tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma) thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) @@ -836,8 +828,7 @@ class TensorOpGemm: if major_mode == utils.LayoutEnum.ROW_MAJOR else cute.make_layout((copy_elems, 1)) ) - tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout) - return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) def raster_tile(self, i, j, f): new_i = i // f @@ -845,20 +836,33 @@ class TensorOpGemm: return (new_i, new_j) -def run_tensor_op_gemm( +def run( a_major: str, b_major: str, c_major: str, ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], - problem_shape: Tuple[int, int, int, int], + mnkl: Tuple[int, int, int, int], atom_layout_mnk: Tuple[int, int, int], warmup_iterations: int = 2, iterations: int = 100, skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, ): - M, N, K, L = problem_shape + print(f"Running Ampere tensor core GEMM example:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Atoms layout: {atom_layout_mnk}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + M, N, K, L = mnkl # Create and permute tensor A/B/C def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype): @@ -866,23 +870,28 @@ def run_tensor_op_gemm( # else: (l, mode0, mode1) -> (mode0, mode1, l) shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) - - return ( + torch_tensor = ( torch.empty(*shape, dtype=torch.int32) .random_(-2, 2) - .to(dtype=dtype) + .to(dtype=cutlass_torch.dtype(dtype)) .permute(permute_order) .cuda() ) + # assume input is 16B aligned + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0)) + .mark_compact_shape_dynamic( + mode=(1 if not is_mode0_major else 0), + stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0), + divisibility=(128 // dtype.width), + ) + ) + return cute_tensor, torch_tensor - a = create_and_permute_tensor( - L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype) - ) - b = create_and_permute_tensor( - L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype) - ) - c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype)) - ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype)) + mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype) + mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype) + mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype) tensor_op_gemm = TensorOpGemm( ab_dtype, @@ -891,56 +900,49 @@ def run_tensor_op_gemm( atom_layout_mnk, ) - # assume input is 16B aligned - a_tensor = ( - from_dlpack(a, assumed_align=16) - .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) - .mark_compact_shape_dynamic( - mode=(1 if a_major == "k" else 0), - stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), - divisibility=(128 // ab_dtype.width), - ) - ) - b_tensor = ( - from_dlpack(b, assumed_align=16) - .mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0)) - .mark_compact_shape_dynamic( - mode=(1 if b_major == "k" else 0), - stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), - divisibility=(128 // ab_dtype.width), - ) - ) - c_tensor = ( - from_dlpack(c, assumed_align=16) - .mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) - .mark_compact_shape_dynamic( - mode=(1 if c_major == "n" else 0), - stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), - divisibility=(128 // c_dtype.width), - ) - ) - print("Compiling kernel with cute.compile ...") - gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor) + compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC) print("Executing GEMM kernel...") + if not skip_ref_check: + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch.to(dtype=torch.float32), + b_torch.to(dtype=torch.float32), + ).to(cutlass_torch.dtype(c_dtype)) + compiled_gemm(mA, mB, mC) + print("Verifying results...") + torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype) + b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype) + c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype) + return testing.JitArguments(a_workspace, b_workspace, c_workspace) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + avg_time_us = testing.benchmark( - gemm, - kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor), + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, warmup_iterations=warmup_iterations, - profiling_iterations=iterations, + iterations=iterations, use_cuda_graphs=False, ) - print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") - - if not skip_ref_check: - gemm(a_tensor, b_tensor, c_tensor) - print("Verifying results...") - torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) - print("Results verified successfully!") - + return avg_time_us # Return execution time in microseconds if __name__ == "__main__": @@ -985,10 +987,15 @@ if __name__ == "__main__": parser.add_argument("--warmup_iterations", default=2, type=int) parser.add_argument("--iterations", default=100, type=int) parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() - print("Running Ampere tensor core GEMM example:") - run_tensor_op_gemm( + run( args.a_major, args.b_major, args.c_major, @@ -1000,5 +1007,6 @@ if __name__ == "__main__": args.warmup_iterations, args.iterations, args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py new file mode 100644 index 00000000..81720ee9 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py @@ -0,0 +1,2467 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import from_dlpack + +""" +This example provides an experimental implementation of the SM100 batched dense blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. + +A high-performance persistent batched dense blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×L elements respectively +- Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, which has N×ceil_div(K, sf_vec_size)×L elements respectively + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: + - Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction. + - Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Type convert C matrix to output type. + - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. + - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma.kind.block_scale instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Read scalefactor A from TMEM +- Read scalefactor B from TMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +Input arguments to this example is shown below: + +.. code-block:: bash + + python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints: +* Supported input data types: mxf8, mxf4, nvf4 + see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 128 or 256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. +""" + + +class Sm100BlockScaledPersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 128/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + + Example: + >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel( + ... sf_vec_size=16, + ... mma_tiler_mn=(256, 128), + ... cluster_shape_mn=(2, 1) + ... ) + >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator, always set to Float32 + - sf_vec_size: Scalefactor A/B vector size. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + - Computing tensor memory allocation columns + """ + # Compute mma instruction shapes + mma_inst_bits_k = 256 + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mnk = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_bits_k // self.a_dtype.width, + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mnk_sfb = ( + self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mnk[1], 128), + self.mma_inst_shape_mnk[2], + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mnk[0], + self.mma_inst_shape_mnk[1], + self.mma_inst_shape_mnk[2] * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mnk_sfb[0], + self.mma_inst_shape_mnk_sfb[1], + self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.a_major_mode, + self.b_dtype, + self.b_major_mode, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + @cute.jit + def __call__( + self, + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + c_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a_tensor: Input tensor A + :type a_tensor: cute.Tensor + :param b_tensor: Input tensor B + :type b_tensor: cute.Tensor + :param sfa_tensor: Scale factor tensor A + :type sfa_tensor: cute.Tensor + :param sfb_tensor: Scale factor tensor B + :type sfb_tensor: cute.Tensor + :param c_tensor: Output tensor C + :type c_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type + self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type + self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_tensor) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, self.sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, self.sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_tensor, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_tensor, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Setup TMA store for C + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c_tensor, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/SFA/SFB/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_block_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_block_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_block_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + # Make accumulator tmem tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_block = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_block_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_block in range(k_block_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kphases = cute.size(tCrA, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = ( + None, + None, + kphase_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kphase_coord = (None, None, kphase_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kphase_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kphase_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kphase_coord], + tCrB[kphase_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kphase + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_block = k_block + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_block_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + ) + + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC_partitioned = ( + self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + a_major_mode: tcgen05.OperandMajorMode, + b_dtype: Type[cutlass.Numeric], + b_major_mode: tcgen05.OperandMajorMode, + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param a_major_mode: Major mode of operand A. + :type a_major_mode: tcgen05.OperandMajorMode + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param b_major_mode: Major mode of operand B. + :type b_major_mode: tcgen05.OperandMajorMode + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of Scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not mma_tiler_mn[0] in [128, 256]: + is_valid = False + if not mma_tiler_mn[1] in [128, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts( + ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: Data type for scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: Vector size for scale factor tensor + :type sf_vec_size: int + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape. + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print(f"Running Sm100 Persistent Dense BlockScaled GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") + print(f"C dtype: {c_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create tensor A/B/C + a_ref = cutlass_torch.matrix(l, m, k, a_major == "m", cutlass.Float32) + b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32) + c_ref = cutlass_torch.matrix(l, m, n, c_major == "m", cutlass.Float32) + + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch = cutlass_torch.cute_tensor_like( + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor to be byte aligned + a_tensor.mark_compact_shape_dynamic( + mode=1 if a_major == "k" else 0, + stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + c_tensor.mark_compact_shape_dynamic( + mode=1 if c_major == "n" else 0, + stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), + divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1, + ) + + # Create scale factor tensor SFA/SFB + def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): + def ceil_div(a, b): + return (a + b - 1) // b + + sf_k = ceil_div(k, sf_vec_size) + ref_shape = (l, mn, sf_k) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + ref_permute_order = (1, 2, 0) + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # Create f32 ref torch tensor (cpu) + ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + ref_shape, + torch.float32, + permute_order=ref_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=1, + max_val=3, + ), + ) + + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + mma_shape, + torch.float32, + permute_order=mma_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0, + max_val=1, + ), + ) + + # convert ref f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_f32_torch_tensor_cpu), + from_dlpack(cute_f32_torch_tensor_cpu), + ) + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() + + # reshape makes memory contiguous + ref_f32_torch_tensor_cpu = ( + ref_f32_torch_tensor_cpu.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. + ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] + + # Create dtype cute torch tensor (cpu) + cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( + cute_f32_torch_tensor_cpu, + dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # Convert f32 cute tensor to dtype cute tensor + cute_tensor = cutlass_torch.convert_cute_tensor( + cute_f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=True, + ) + return ref_f32_torch_tensor_cpu, cute_tensor, cute_torch_tensor + + sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype + ) + sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype + ) + + # Configure gemm kernel + gemm = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + max_active_clusters, + current_stream, + ) + + # Compute reference result + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_gemm( + a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream + ) + print("Verifying results...") + res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) + + # Convert c back to f32 for comparison. + c_ref_device = c_ref.cuda() + cute.testing.convert( + c_tensor, + from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=(1 if c_major == "n" else 0) + ), + ) + c_ref = c_ref_device.cpu() + + if c_dtype in (cutlass.Float32, cutlass.Float16, cutlass.BFloat16): + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + elif c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): + # Convert ref : f32 -> f8 -> f32 + ref_f8_ = torch.empty(*(l, m, n), dtype=torch.uint8, device="cuda").permute( + 1, 2, 0 + ) + ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + ref_f8.element_type = c_dtype + ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda() + ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + cute.testing.convert(ref_tensor, ref_f8) + cute.testing.convert(ref_f8, ref_tensor) + ref = ref_device.cpu() + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor to be byte aligned + a_tensor.mark_compact_shape_dynamic( + mode=1 if a_major == "k" else 0, + stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + c_tensor.mark_compact_shape_dynamic( + mode=1 if c_major == "n" else 0, + stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), + divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1, + ) + + _, sfa_tensor, _ = create_scale_factor_tensor(l, m, k, sf_vec_size, sf_dtype) + _, sfb_tensor, _ = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype) + return cute.testing.JitArguments( + a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + sfa_torch.numel() * sfa_torch.element_size() + + sfb_torch.numel() * sfb_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = cute.testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = cute.testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Sm100 Dense Persistent BlockScaled GEMM." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(512, 256, 256, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN) + parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E8M0FNU) + parser.add_argument("--sf_vec_size", type=int, default=16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.sf_dtype, + args.sf_vec_size, + args.c_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm.py b/examples/python/CuTeDSL/blackwell/dense_gemm.py index 77c5c923..c36c28a8 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -212,7 +212,7 @@ class DenseGemmKernel: self.occupancy = 1 self.threads_per_cta = 128 - self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -1106,11 +1106,7 @@ class DenseGemmKernel: copy_atom_r2s = sm100_utils.get_smem_store_op( self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r ) - tiled_copy_r2s = cute.make_tiled_copy( - copy_atom_r2s, - layout_tv=tiled_copy_t2r.layout_dst_tv_tiled, - tiler_mn=tiled_copy_t2r.tiler_mn, - ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) tRS_sC = thr_copy_r2s.partition_D(sC) @@ -1772,7 +1768,7 @@ def run_dense_gemm( ref_c = ref elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: # m major: (l, n, m) -> (m, n, l) - # k major: (l, m, n) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) shape = (l, m, n) if c_major == "n" else (l, n, m) f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py index e577b7a7..3251abbd 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py @@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.torch as cutlass_torch import cutlass.utils as utils import cutlass.pipeline as pipeline +import cutlass.cute.testing as testing import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cute.runtime import from_dlpack @@ -226,7 +227,7 @@ class PersistentDenseGemmKernel: self.cta_sync_bar_id = 0 self.epilog_sync_bar_id = 1 self.tmem_ptr_sync_bar_id = 2 - self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel: copy_atom_r2s = sm100_utils.get_smem_store_op( self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r ) - tiled_copy_r2s = cute.make_tiled_copy( - copy_atom_r2s, - layout_tv=tiled_copy_t2r.layout_dst_tv_tiled, - tiler_mn=tiled_copy_t2r.tiler_mn, - ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) tRS_sC = thr_copy_r2s.partition_D(sC) @@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel: return can_implement -def run_dense_gemm( +def run( mnkl: Tuple[int, int, int, int], ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], @@ -1832,17 +1829,58 @@ def run_dense_gemm( a_major: str, b_major: str, c_major: str, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - use_2cta_instrs: bool, - use_tma_store: bool, - tolerance: float, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Tuple[int, int] = (2, 1), + use_2cta_instrs: bool = True, + use_tma_store: bool = True, + tolerance: float = 1e-01, warmup_iterations: int = 0, iterations: int = 1, skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, ): - """ - Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the + default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type mma_tiler_mn: Tuple[int, int], optional + :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the + default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type cluster_shape_mn: Tuple[int, int], optional + :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner + will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_2cta_instrs: bool, optional + :param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use + the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_tma_store: bool, optional + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float """ print(f"Running Blackwell Persistent Dense GEMM test with:") print(f"mnkl: {mnkl}") @@ -1855,6 +1893,7 @@ def run_dense_gemm( print(f"Warmup iterations: {warmup_iterations}") print(f"Iterations: {iterations}") print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") # Unpack parameters m, n, k, l = mnkl @@ -1931,15 +1970,15 @@ def run_dense_gemm( is_dynamic_layout=is_dynamic_layout, ) - return f32_torch_tensor, cute_tensor, torch_tensor + return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu - a_ref, a_tensor, a_torch = create_and_permute_tensor( + a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor( l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True ) - b_ref, b_tensor, b_torch = create_and_permute_tensor( + b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor( l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True ) - c_ref, c_tensor, c_torch = create_and_permute_tensor( + c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor( l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True ) @@ -1967,16 +2006,8 @@ def run_dense_gemm( gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream ) - # Launch GPU kernel - # Warm up - for i in range(warmup_iterations): - compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) - # Execution - for i in range(iterations): - compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) - - # Compute reference result if not skip_ref_check: + compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) if ab_dtype in { cutlass.Int8, cutlass.Uint8, @@ -2028,6 +2059,40 @@ def run_dense_gemm( rtol=1e-05, ) + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + if __name__ == "__main__": @@ -2090,6 +2155,12 @@ if __name__ == "__main__": parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference checking" ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() @@ -2102,7 +2173,7 @@ if __name__ == "__main__": if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") - run_dense_gemm( + run( args.mnkl, args.ab_dtype, args.c_dtype, @@ -2118,5 +2189,6 @@ if __name__ == "__main__": args.warmup_iterations, args.iterations, args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py index 8f6ab9fb..d69ad401 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py @@ -223,7 +223,7 @@ class DenseGemmKernel: self.occupancy = 1 self.threads_per_cta = 128 - self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -1063,11 +1063,7 @@ class DenseGemmKernel: copy_atom_r2s = sm100_utils.get_smem_store_op( self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r ) - tiled_copy_r2s = cute.make_tiled_copy( - copy_atom_r2s, - layout_tv=tiled_copy_t2r.layout_dst_tv_tiled, - tiler_mn=tiled_copy_t2r.tiler_mn, - ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) tRS_sC = thr_copy_r2s.partition_D(sC) diff --git a/examples/python/CuTeDSL/blackwell/fmha.py b/examples/python/CuTeDSL/blackwell/fmha.py index ce4cafe9..537d9b43 100644 --- a/examples/python/CuTeDSL/blackwell/fmha.py +++ b/examples/python/CuTeDSL/blackwell/fmha.py @@ -43,6 +43,7 @@ import cutlass.utils as utils import cutlass.pipeline as pipeline import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.cute.testing as testing from cutlass.cute.runtime import from_dlpack from cutlass.cute.typing import Int32, Int64, Float32, Boolean @@ -90,7 +91,7 @@ Constraints for this example: * Number of heads in Q must be divisible by number of heads in K * mma_tiler_mn must be 128,128 * Batch size must be the same for Q, K, and V tensors -* For causal masking, use --has_casual_mask (note: specify without =True/False) +* For causal masking, use --is_causal (note: specify without =True/False) * For persistent scheduling, use --is_persistent (note: specify without =True/False) """ @@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward: smem_copy_atom = sm100_utils.get_smem_store_op( self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load ) - tiled_smem_store = cute.make_tiled_copy( - smem_copy_atom, - layout_tv=tiled_tmem_load.layout_dst_tv_tiled, - tiler_mn=tiled_tmem_load.tiler_mn, - ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) @@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward: return tile_sched_params, grid -def run_fmha_and_verify( +def run( q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], in_dtype: Type[cutlass.Numeric], @@ -2628,7 +2625,7 @@ def run_fmha_and_verify( pv_acc_dtype: Type[cutlass.Numeric], mma_tiler_mn: Tuple[int, int], is_persistent: bool, - has_casual_mask: bool, + is_causal: bool, scale_q: float, scale_k: float, scale_v: float, @@ -2638,6 +2635,8 @@ def run_fmha_and_verify( warmup_iterations: int, iterations: int, skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, ): """Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results. @@ -2670,8 +2669,8 @@ def run_fmha_and_verify( :type mma_tiler_mn: Tuple[int, int] :param is_persistent: Whether to use persistent kernel optimization :type is_persistent: bool - :param has_casual_mask: Whether to apply causal masking - :type has_casual_mask: bool + :param is_causal: Whether to apply causal masking + :type is_causal: bool :param scale_q: Scaling factor for query tensor :type scale_q: float :param scale_k: Scaling factor for key tensor @@ -2690,9 +2689,13 @@ def run_fmha_and_verify( :type iterations: int :param skip_ref_check: Skip validation against reference implementation :type skip_ref_check: bool + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache + :type use_cold_l2: bool :raises ValueError: If input shapes are incompatible or head dimension is unsupported :raises RuntimeError: If GPU is unavailable for computation + :return: Execution time of the FMHA kernel in microseconds + :rtype: float """ print(f"Running Blackwell SM100 FMHA test with:") @@ -2704,13 +2707,17 @@ def run_fmha_and_verify( print(f" pv_acc_dtype: {pv_acc_dtype}") print(f" mma_tiler_mn: {mma_tiler_mn}") print(f" is_persistent: {is_persistent}") - print(f" has_casual_mask: {has_casual_mask}") + print(f" is_causal: {is_causal}") print(f" scale_q: {scale_q}") print(f" scale_k: {scale_k}") print(f" scale_v: {scale_v}") print(f" inv_scale_o: {inv_scale_o}") print(f" scale_softmax: {scale_softmax}") print(f" tolerance: {tolerance}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") # Unpack parameters b, s_q, h_q, d = q_shape @@ -2882,7 +2889,7 @@ def run_fmha_and_verify( mma_tiler = (*mma_tiler_mn, d) mask_type = MaskType.NO_MASK - if has_casual_mask: + if is_causal: mask_type = MaskType.CAUSAL_MASK else: if isinstance(s_k, tuple): @@ -2942,41 +2949,7 @@ def run_fmha_and_verify( compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") - # Warmup - for _ in range(warmup_iterations): - compiled_fmha( - q_tensor.iterator, - k_tensor.iterator, - v_tensor.iterator, - o_tensor.iterator, - problem_size, - cum_seqlen_q, - cum_seqlen_k, - scale_softmax_log2, - scale_output, - current_stream, - ) - - # Execute kernel - for _ in range(iterations): - compiled_fmha( - q_tensor.iterator, - k_tensor.iterator, - v_tensor.iterator, - o_tensor.iterator, - problem_size, - cum_seqlen_q, - cum_seqlen_k, - scale_softmax_log2, - scale_output, - current_stream, - ) - - torch.cuda.synchronize() - - def run_torch_fmha( - q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False - ): + def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False): h_q = q.shape[2] h_k = k.shape[2] @@ -3005,7 +2978,7 @@ def run_fmha_and_verify( v = v.transpose(1, 2) # For the situation that torch has not supported, we need to handle it manually - situation1 = has_casual_mask and (q.is_nested or k.is_nested) + situation1 = is_causal and (q.is_nested or k.is_nested) situation2 = (q.is_nested and not k.is_nested) or ( not q.is_nested and k.is_nested ) @@ -3025,8 +2998,9 @@ def run_fmha_and_verify( attn_mask=None, dropout_p=0.0, scale=scale_softmax, - is_causal=has_casual_mask, + is_causal=is_causal, ) + ref_i = ref_i.transpose(0, 1) * scale_output ref_list.append(ref_i) if q.is_nested: ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged) @@ -3040,15 +3014,28 @@ def run_fmha_and_verify( attn_mask=None, dropout_p=0.0, scale=scale_softmax, - is_causal=has_casual_mask, + is_causal=is_causal, ) - ref = ref.transpose(1, 2) * scale_output + ref = ref.transpose(1, 2) * scale_output return ref if not skip_ref_check: + # Execute kernel once for reference checking + compiled_fmha( + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) print("Verifying results...") o_ref = run_torch_fmha( - q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask + q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal ) if o_ref.is_nested: @@ -3095,6 +3082,76 @@ def run_fmha_and_verify( torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05) print("Results verified successfully!") + def generate_tensors(): + _, q_tensor_workspace, _ = create_and_pad_tensor( + qo_shape, + qo_padding, + in_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + _, k_tensor_workspace, _ = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, v_tensor_workspace, _ = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, o_tensor_workspace, _ = create_and_pad_tensor( + qo_shape, + qo_padding, + out_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + return testing.JitArguments( + q_tensor_workspace.iterator, + k_tensor_workspace.iterator, + v_tensor_workspace.iterator, + o_tensor_workspace.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch + k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch + v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch + o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch + one_workspace_bytes = ( + q_torch_effective.numel() * q_torch_effective.element_size() + + k_torch_effective.numel() * k_torch_effective.element_size() + + v_torch_effective.numel() * v_torch_effective.element_size() + + o_torch_effective.numel() * o_torch_effective.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_fmha, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + if __name__ == "__main__": def parse_comma_separated_ints(s: str): @@ -3185,7 +3242,7 @@ if __name__ == "__main__": ) parser.add_argument( - "--has_casual_mask", + "--is_causal", action="store_true", help="Whether to use casual mask", ) @@ -3263,6 +3320,13 @@ if __name__ == "__main__": help="Skip reference check", ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + args = parser.parse_args() if len(args.q_shape) != 4: @@ -3279,7 +3343,7 @@ if __name__ == "__main__": torch.manual_seed(1111) - run_fmha_and_verify( + run( args.q_shape, args.k_shape, args.in_dtype, @@ -3288,7 +3352,7 @@ if __name__ == "__main__": args.pv_acc_dtype, args.mma_tiler_mn, args.is_persistent, - args.has_casual_mask, + args.is_causal, args.scale_q, args.scale_k, args.scale_v, @@ -3298,6 +3362,7 @@ if __name__ == "__main__": args.warmup_iterations, args.iterations, args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_gemm.py index 3b67bf3b..0dba2bb6 100644 --- a/examples/python/CuTeDSL/blackwell/grouped_gemm.py +++ b/examples/python/CuTeDSL/blackwell/grouped_gemm.py @@ -36,6 +36,7 @@ import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute +import cutlass.cute.testing as testing import cutlass.utils as utils from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils @@ -157,7 +158,7 @@ class GroupedGemmKernel: self.tmem_ptr_sync_bar_id = 2 # Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion self.tensormap_ab_init_bar_id = 4 - self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") self.num_tma_load_bytes = 0 def _setup_attributes(self): @@ -951,7 +952,7 @@ class GroupedGemmKernel: # Specialized MMA warp # if warp_idx == self.mma_warp_id: - # initilize tensormap A, B for TMA warp + # initialize tensormap A, B for TMA warp if cutlass.const_expr(self.delegate_tensormap_ab_init): tensormap_manager.init_tensormap_from_atom( tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id @@ -1540,11 +1541,7 @@ class GroupedGemmKernel: copy_atom_r2s = sm100_utils.get_smem_store_op( self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r ) - tiled_copy_r2s = cute.make_tiled_copy( - copy_atom_r2s, - layout_tv=tiled_copy_t2r.layout_dst_tv_tiled, - tiler_mn=tiled_copy_t2r.tiler_mn, - ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) tRS_sC = thr_copy_r2s.partition_D(sC) @@ -1815,7 +1812,136 @@ class GroupedGemmKernel: tensor_memory_management_bytes = 12 -def run_grouped_gemm( +# Create tensor and return the pointer, tensor, and stride +def create_tensor_and_stride( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: type[cutlass.Numeric], + is_dynamic_layout: bool = True, + torch_tensor_cpu: torch.Tensor = None, +) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: + """Create a GPU tensor from scratch or based on an existing CPU tensor. + + :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. + :type torch_tensor_cpu: torch.Tensor, optional + """ + if torch_tensor_cpu is None: + # Create new CPU tensor + torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype) + + # Create GPU tensor from CPU tensor (new or existing) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 + ) + return ( + torch_tensor.data_ptr(), + torch_tensor, + cute_tensor, + torch_tensor_cpu, + torch_tensor.stride()[:-1], + ) + + +def create_tensors_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + torch_fp32_tensors_abc: List[List[torch.Tensor]] = None, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[tuple]], + List[List[torch.Tensor]], +]: + if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len( + problem_sizes_mnkl + ): + raise ValueError("torch_fp32_tensors_abc must have one entry per group") + + # Initialize lists to store tensors for all groups + new_torch_fp32_tensors_abc = ( + [] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc + ) + torch_tensors_abc = [] + cute_tensors_abc = [] + strides_abc = [] + ptrs_abc = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + # Get existing CPU tensors if available, otherwise None + existing_cpu_a = ( + torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None + ) + existing_cpu_b = ( + torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None + ) + existing_cpu_c = ( + torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None + ) + + # Create tensors (reusing CPU tensors if provided) + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + tensor_fp32_a, + stride_mk_a, + ) = create_tensor_and_stride( + l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a + ) + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + tensor_fp32_b, + stride_nk_b, + ) = create_tensor_and_stride( + l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b + ) + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + tensor_fp32_c, + stride_mn_c, + ) = create_tensor_and_stride( + l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c + ) + + # Only append to new_torch_fp32_tensors_abc if we created new CPU tensors + if torch_fp32_tensors_abc is None: + new_torch_fp32_tensors_abc.append( + [tensor_fp32_a, tensor_fp32_b, tensor_fp32_c] + ) + + ptrs_abc.append([ptr_a, ptr_b, ptr_c]) + torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + cute_tensors_abc.append( + ( + cute_tensor_a, + cute_tensor_b, + cute_tensor_c, + ) + ) + + return ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + new_torch_fp32_tensors_abc, + ) + + +def run( num_groups: int, problem_sizes_mnkl: tuple[int, int, int, int], ab_dtype: Type[cutlass.Numeric], @@ -1832,8 +1958,16 @@ def run_grouped_gemm( warmup_iterations: int, iterations: int, skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, ): - """Run grouped GEMM example with specified configurations.""" + """Run grouped GEMM example with specified configurations. + + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ print(f"Running Blackwell Grouped GEMM test with:") print(f"{num_groups} groups") for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): @@ -1847,6 +1981,7 @@ def run_grouped_gemm( print(f"Warmup iterations: {warmup_iterations}") print(f"Iterations: {iterations}") print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") # Skip unsupported types if ab_dtype not in { @@ -1902,66 +2037,22 @@ def run_grouped_gemm( if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") - # Create tensor and return the pointer, tensor, and stride - def create_tensor_and_stride( - l: int, - mode0: int, - mode1: int, - is_mode0_major: bool, - dtype: type[cutlass.Numeric], - is_dynamic_layout: bool = True, - ) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: - torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype) - cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( - torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 - ) - return ( - torch_tensor.data_ptr(), - torch_tensor, - cute_tensor, - torch_tensor_cpu, - torch_tensor.stride()[:-1], - ) + # Create tensors for all groups using the new function + ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + torch_fp32_tensors_abc, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) - # iterate all groups and create tensors for each group - torch_fp32_tensors_abc = [] - torch_tensors_abc = [] - cute_tensors_abc = [] - strides_abc = [] - ptrs_abc = [] - for _, (m, n, k, l) in enumerate(problem_sizes_mnkl): - ( - ptr_a, - torch_tensor_a, - cute_tensor_a, - tensor_fp32_a, - stride_mk_a, - ) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype) - ( - ptr_b, - torch_tensor_b, - cute_tensor_b, - tensor_fp32_b, - stride_nk_b, - ) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype) - ( - ptr_c, - torch_tensor_c, - cute_tensor_c, - tensor_fp32_c, - stride_mn_c, - ) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype) - ptrs_abc.append([ptr_a, ptr_b, ptr_c]) - torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) - torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c]) - strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) - cute_tensors_abc.append( - ( - cute_tensor_a, - cute_tensor_b, - cute_tensor_c, - ) - ) # Choose A, B, C with the smallest size to create initial tensormaps key_size_a = lambda item: item[1][0] * item[1][2] key_size_b = lambda item: item[1][1] * item[1][2] @@ -2078,36 +2169,19 @@ def run_grouped_gemm( current_stream, ) - # Launch GPU kernel - # Warm up - for _ in range(warmup_iterations): - compiled_grouped_gemm( - initial_cute_tensors_abc[0], - initial_cute_tensors_abc[1], - initial_cute_tensors_abc[2], - tensor_of_dim_size_mnkl, - tensor_of_strides_abc, - tensor_of_ptrs_abc, - tensor_of_tensormap, - current_stream, - ) - # Execution - for i in range(iterations): - compiled_grouped_gemm( - initial_cute_tensors_abc[0], - initial_cute_tensors_abc[1], - initial_cute_tensors_abc[2], - tensor_of_dim_size_mnkl, - tensor_of_strides_abc, - tensor_of_ptrs_abc, - tensor_of_tensormap, - current_stream, - ) - - torch.cuda.synchronize() - - # Compute reference result if not skip_ref_check: + compiled_grouped_gemm( + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_tensormap, + current_stream, + ) + + # Compute reference result for i, (a, b, c) in enumerate(torch_tensors_abc): ref = torch.einsum( "mkl,nkl->mnl", @@ -2122,6 +2196,102 @@ def run_grouped_gemm( rtol=1e-05, ) + def generate_tensors(): + # Reuse existing CPU tensors and create new GPU tensors from them + ( + ptrs_abc_workspace, + torch_tensors_abc_workspace, + cute_tensors_abc_workspace, + strides_abc_workspace, + _, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + torch_fp32_tensors_abc, + ) + + initial_cute_tensors_abc_workspace = [ + cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n) + ] + + # Create new tensors for this workspace + tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc_workspace, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensormap_workspace, _ = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + return testing.JitArguments( + initial_cute_tensors_abc_workspace[0], + initial_cute_tensors_abc_workspace[1], + initial_cute_tensors_abc_workspace[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc_workspace, + tensor_of_ptrs_abc_workspace, + tensormap_workspace, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + sum( + [ + sum( + [ + torch_tensor.numel() * torch_tensor.element_size() + for torch_tensor in group_tensors + ] + ) + for group_tensors in torch_tensors_abc + ] + ) + + + # Add size of strides tensor + tensor_of_strides_abc_torch.numel() + * tensor_of_strides_abc_torch.element_size() + + + # Add size of ptrs tensor + tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + + + # Add size of tensormap tensor + tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_grouped_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + if __name__ == "__main__": @@ -2218,6 +2388,12 @@ if __name__ == "__main__": parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference checking" ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() @@ -2248,7 +2424,7 @@ if __name__ == "__main__": torch.manual_seed(2025) - run_grouped_gemm( + run( args.num_groups, args.problem_sizes_mnkl, args.ab_dtype, @@ -2265,5 +2441,6 @@ if __name__ == "__main__": args.warmup_iterations, args.iterations, args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py index 342d5580..e77221bb 100644 --- a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py @@ -29,13 +29,14 @@ import argparse from typing import List, Type, Tuple, Optional -from cuda import cuda +import cuda.bindings.driver as cuda import torch import torch.nn.functional as F import cutlass import cutlass.cute as cute +import cutlass.cute.testing as testing import cutlass.utils as utils import cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 @@ -43,13 +44,16 @@ import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cute.runtime import from_dlpack -from .mamba2_ssd_reference import ( +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent)) +from mamba2_ssd_reference import ( ssd_reference_fp32_all, ssd_reference_lowprecision_intermediates, analyze_relative_diffs, ) - -from .mamba2_ssd_tile_scheduler import ( +from mamba2_ssd_tile_scheduler import ( Mamba2SSDTileSchedulerParams, Mamba2SSDTileScheduler, ) @@ -122,7 +126,7 @@ class SSDKernel: *self.epilog_warp_id, ) ) - self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") # Named barriers self.pre_inter_sync_bar_id = 1 @@ -1522,7 +1526,10 @@ class SSDKernel: # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b( - local_tidx, smem_bt_internal_, tiled_s2r_b, tBrB_s2r + local_tidx, + smem_bt_internal_, + tiled_s2r_b, + tBrB_s2r, ) # (MMA, MMA_M, MMA_K, INPUT_STAGE) @@ -3053,7 +3060,7 @@ class SSDKernel: # SegSum # fadd2 + fsel + fmul2/mufu + fmul2 - for subtile_idx in range(0, cute.size(tTR_rQ), 2): + for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True): ( tCompute[subtile_idx], tCompute[subtile_idx + 1], @@ -3061,11 +3068,11 @@ class SSDKernel: (tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]), (-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]), ) - for subtile_idx in range(cute.size(tTR_rQ)): + for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True): m, n = tCoord[subtile_idx] if m < n: tCompute[subtile_idx] = cutlass.Float32(-float("inf")) - for subtile_idx in range(0, cute.size(tTR_rQ), 2): + for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True): # TODO: use math.exp directly ( tCompute[subtile_idx], @@ -3130,11 +3137,7 @@ class SSDKernel: dtype, num_bits_per_copy=128, ) - tiled_r2s_b = cute.make_tiled_copy( - copy_atom_r2s_b, - layout_tv=tiled_s2r_b.layout_tv_tiled, - tiler_mn=tiled_s2r_b.tiler_mn, - ) + tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b) thr_r2s_b = tiled_r2s_b.get_slice(local_tidx) # Partition shared tensor for smem store Bt @@ -3333,17 +3336,24 @@ class SSDKernel: ) -def run_ssd( +def run( gbehcdln: Tuple[int, int, int, int, int, int, int, int], io_dtype: Type[cutlass.Numeric], cumsum_delta_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], - has_d: bool, - d_has_hdim: bool, + fuse_scale_d: str, tolerance: float, print_rtol_stats: bool, ref_lower_precision: bool, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, ): + has_d = fuse_scale_d != "none" + d_has_hdim = fuse_scale_d == "vector" + print(f"Running B100 Mamba2 SSD with:") print(f"GBEHCDLN: {gbehcdln}") print( @@ -3353,6 +3363,10 @@ def run_ssd( f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}" ) print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") # Unpack parameters G, B, E, H, C, D, L, N = gbehcdln @@ -3515,39 +3529,146 @@ def run_ssd( stream, ) - # Launch compiled ssd kernel - compiled_ssd( - x_tensor, - cumsum_delta_tensor, - delta_tensor, - b_tensor, - c_tensor, - y_tensor, - fstate_tensor, - d_tensor, - stream, + # Launch compiled ssd kernel for reference check + if not skip_ref_check: + compiled_ssd( + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + stream, + ) + + # Reference check + if print_rtol_stats: + print("\nY's Relative diffs:") + analyze_relative_diffs( + y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)) + ) + print("\nFstate's Relative diffs:") + analyze_relative_diffs( + fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype)) + ) + torch.testing.assert_close( + y_torch.cpu(), + y_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-02, + ) + torch.testing.assert_close( + fstate_torch.cpu(), + fstate_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + # Reuse existing CPU reference tensors and create new GPU tensors from them + _, x_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, C, L], + [2, 4, 3, 1, 0], + io_dtype, + ref_tensor=x_ref, + dynamic_modes=[2, 3, 4], + ) + _, cumsum_delta_tensor_new, _ = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + cumsum_delta_dtype, + ref_tensor=cumsum_delta_ref, + dynamic_modes=[1, 2, 3], + ) + _, delta_tensor_new, _ = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + io_dtype, + ref_tensor=delta_ref, + dynamic_modes=[1, 2, 3], + ) + _, b_tensor_new, _ = create_and_permute_tensor( + [B, G, N, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=b_ref, + dynamic_modes=[2, 3, 4], + ) + _, c_tensor_new, _ = create_and_permute_tensor( + [B, G, N, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=c_ref, + dynamic_modes=[2, 3, 4], + ) + _, y_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=y_ref, + dynamic_modes=[2, 3, 4], + ) + _, fstate_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, N], + [2, 3, 1, 0], + io_dtype, + ref_tensor=fstate_ref, + dynamic_modes=[2, 3], + ) + + if has_d: + _, d_tensor_new, _ = create_and_permute_tensor( + [EH, D if d_has_hdim else 1], + [1, 0], + io_dtype, + ref_tensor=d_ref, + dynamic_modes=[1], + ) + else: + d_tensor_new = d_tensor + + return testing.JitArguments( + x_tensor_new, + cumsum_delta_tensor_new, + delta_tensor_new, + b_tensor_new, + c_tensor_new, + y_tensor_new, + fstate_tensor_new, + d_tensor_new, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + x_torch.numel() * x_torch.element_size() + + cumsum_delta_torch.numel() * cumsum_delta_torch.element_size() + + delta_torch.numel() * delta_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + + y_torch.numel() * y_torch.element_size() + + fstate_torch.numel() * fstate_torch.element_size() + ) + if has_d: + one_workspace_bytes += d_torch.numel() * d_torch.element_size() + + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_ssd, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, ) - # Reference check - if print_rtol_stats: - print("\nY's Relative diffs:") - analyze_relative_diffs(y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))) - print("\nFstate's Relative diffs:") - analyze_relative_diffs( - fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype)) - ) - torch.testing.assert_close( - y_torch.cpu(), - y_ref.to(cutlass_torch.dtype(io_dtype)), - atol=tolerance, - rtol=1e-02, - ) - torch.testing.assert_close( - fstate_torch.cpu(), - fstate_ref.to(cutlass_torch.dtype(io_dtype)), - atol=tolerance, - rtol=1e-05, - ) + return exec_time # Return execution time in microseconds if __name__ == "__main__": @@ -3586,15 +3707,53 @@ if __name__ == "__main__": ) parser.add_argument( "--ref_lower_precision", - type=bool, + action="store_true", default=True, help="Use lower precision for reference check", ) + parser.add_argument( + "--no-ref_lower_precision", + action="store_false", + dest="ref_lower_precision", + default=False, + help="Disable lower precision for reference check", + ) parser.add_argument( "--tolerance", type=float, default=5e-02, help="Tolerance for validation" ) parser.add_argument( - "--print_rtol_stats", type=bool, default=True, help="Print rtol stats" + "--print_rtol_stats", + action="store_true", + default=True, + help="Enable print rtol stats", + ) + parser.add_argument( + "--no-print_rtol_stats", + action="store_false", + dest="print_rtol_stats", + default=False, + help="Disable print rtol stats", + ) + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of warmup iterations", + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", ) args = parser.parse_args() @@ -3602,18 +3761,18 @@ if __name__ == "__main__": if len(args.gbehcdln) != 8: parser.error("--gbehcdln must contain exactly 8 values") - has_d = args.fuse_scale_d != "none" - d_has_hdim = args.fuse_scale_d == "vector" - - run_ssd( + run( args.gbehcdln, args.io_dtype, args.cumsum_delta_dtype, args.acc_dtype, - has_d, - d_has_hdim, + args.fuse_scale_d, args.tolerance, args.print_rtol_stats, args.ref_lower_precision, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/hopper/dense_gemm.py b/examples/python/CuTeDSL/hopper/dense_gemm.py index d9d9783d..6bab06ea 100644 --- a/examples/python/CuTeDSL/hopper/dense_gemm.py +++ b/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -35,6 +35,7 @@ import torch import cutlass import cutlass.cute as cute +import cutlass.cute.testing as testing import cutlass.utils as utils import cutlass.pipeline as pipeline import cutlass.torch as cutlass_torch @@ -166,6 +167,24 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument( "--tolerance", type=float, default=1e-01, help="Tolerance for validation" ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() @@ -264,7 +283,7 @@ class HopperWgmmaGemmKernel: self.mma_warp_groups = math.prod(self.atom_layout_mnk) self.num_threads_per_warp_group = 128 self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group - self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"] + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90") self.ab_stage = None self.epi_stage = None @@ -1309,7 +1328,7 @@ class HopperWgmmaGemmKernel: }: is_valid = False # tested acc_dtype - if acc_dtype != cutlass.Float32: + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: is_valid = False # tested c_dtype if c_dtype not in { @@ -1335,7 +1354,7 @@ class HopperWgmmaGemmKernel: return is_valid -def run_dense_gemm( +def run( mnkl: Tuple[int, int, int, int], a_dtype: Type[cutlass.Numeric], b_dtype: Type[cutlass.Numeric], @@ -1347,9 +1366,43 @@ def run_dense_gemm( tile_shape_mnk: Tuple[int, int, int], cluster_shape_mn: Tuple[int, int], tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, ): """ Prepare A/B/C tensors, launch GPU kernel, and reference checking. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param a_dtype: Data type for input tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: Data type for input tensor B + :type b_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param tile_shape_mnk: CTA tile shape (M, N, K) + :type tile_shape_mnk: Tuple[int, int, int] + :param cluster_shape_mn: Cluster shape (M, N) + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison + :type tolerance: float + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float """ print(f"Running Hopper Dense GEMM with:") @@ -1360,6 +1413,10 @@ def run_dense_gemm( print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}") print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") # Unpack parameters m, n, k, l = mnkl @@ -1437,46 +1494,76 @@ def run_dense_gemm( stream = cuda.CUstream(torch_stream.cuda_stream) # compile gemm kernel compiled_gemm = cute.compile(gemm, mA, mB, mC, stream) - # execution - compiled_gemm(mA, mB, mC, stream) - torch.cuda.synchronize() + if not skip_ref_check: + # execution + compiled_gemm(mA, mB, mC, stream) - # Ref check - ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu() + torch.cuda.synchronize() - if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2): - # m major: (l, n, m) -> (m, n, l) - # k major: (l, m, n) -> (m, n, l) - permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) - shape = (l, m, n) if c_major == "n" else (l, n, m) - f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch.uint8, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.SKIP, - ).cuda() - # Create dtype cute tensor (gpu) - ref_c_tensor = from_dlpack( - f8_torch_tensor, assumed_align=16 - ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) - ref_c_tensor.element_type = c_dtype - ref_c_tensor = cutlass_torch.convert_cute_tensor( - ref, - ref_c_tensor, - c_dtype, - is_dynamic_layout=True, + # Ref check + ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu() + + if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2): + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03) + + def generate_tensors(): + _, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype) + _, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype) + _, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) + return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations ) - ref_c = f8_torch_tensor.cpu() - else: - ref_c = ref.to(cutlass_torch.dtype(c_dtype)) - torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03) + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds if __name__ == "__main__": args = parse_arguments() - run_dense_gemm( + run( args.mnkl, args.a_dtype, args.b_dtype, @@ -1488,5 +1575,9 @@ if __name__ == "__main__": args.tile_shape_mnk, args.cluster_shape_mn, args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/notebooks/print.ipynb b/examples/python/CuTeDSL/notebooks/print.ipynb index 64787bb4..ce1cf9ec 100644 --- a/examples/python/CuTeDSL/notebooks/print.ipynb +++ b/examples/python/CuTeDSL/notebooks/print.ipynb @@ -399,6 +399,70 @@ "\n", "tensor_print_example3()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def print_tensor_gpu(src: cute.Tensor):\n", + " print(src)\n", + " cute.print_tensor(src)\n", + "\n", + "@cute.jit\n", + "def print_tensor_host(src: cute.Tensor):\n", + " print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor o (4,3):(3,1)>\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n", + " [[-0.690547, -0.274619, -1.659539, ],\n", + " [-1.843524, -1.648711, 1.163431, ],\n", + " [-0.716668, -1.900705, 0.592515, ],\n", + " [ 0.711333, -0.552422, 0.860237, ]])\n" + ] + } + ], + "source": [ + "import torch\n", + "def tensor_print_example4():\n", + " a = torch.randn(4, 3, device=\"cuda\")\n", + " cutlass.cuda.initialize_cuda_context()\n", + " print_tensor_host(from_dlpack(a))\n", + "\n", + "tensor_print_example4()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future." + ] } ], "metadata": { diff --git a/examples/python/CuTeDSL/notebooks/tensor.ipynb b/examples/python/CuTeDSL/notebooks/tensor.ipynb index 80b9cff1..ad5c7fc1 100644 --- a/examples/python/CuTeDSL/notebooks/tensor.ipynb +++ b/examples/python/CuTeDSL/notebooks/tensor.ipynb @@ -256,16 +256,6 @@ " cute.printf(\"a[2,3] = {}\", a[2,3])\n", " cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n", "\n", - "@cute.kernel\n", - "def print_tensor_gpu(ptr: cute.Pointer):\n", - " layout = cute.make_layout((8, 5), stride=(5, 1))\n", - " tensor = cute.make_tensor(ptr, layout)\n", - "\n", - " tidx, _, _ = cute.arch.thread_idx()\n", - "\n", - " if tidx == 0:\n", - " cute.print_tensor(tensor)\n", - "\n", "\n", "# Create a tensor with sequential data using torch\n", "data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n", diff --git a/examples/python/deprecated/03_basic_conv2d.ipynb b/examples/python/deprecated/03_basic_conv2d.ipynb index d0eb4526..09ebd7bd 100644 --- a/examples/python/deprecated/03_basic_conv2d.ipynb +++ b/examples/python/deprecated/03_basic_conv2d.ipynb @@ -363,7 +363,7 @@ "| | \"few_channels\" | optimized for small `C` and requires `C % alignment_input == 0`|\n", "| | \"fixed_channels\" | optimized for small `C` and requires `C == alignment_input` |\n", "|Dgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n", - "| | \"optimized\" | Optimzed for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n", + "| | \"optimized\" | Optimized for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n", "|Wgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n", "| | \"optimized\" | Optimized for and require `K % alignment_grad_output == 0`, and `C % alignment_input == 0`|\n", "\n", diff --git a/include/cutlass/arch/wmma.h b/include/cutlass/arch/wmma.h index 9cb9c04f..2d4861ab 100644 --- a/include/cutlass/arch/wmma.h +++ b/include/cutlass/arch/wmma.h @@ -177,7 +177,7 @@ struct WmmaToCutlassDataType<__nv_bfloat16> { ///////////////////////////////////////////////////////////////////////////////////////////////// // WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks -// for a specific template paramterized data type (Element[A|B|C]), layout (Layout[A|B|C]), +// for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]), // and native wmma size (Shape) ///////////////////////////////////////////////////////////////////////////////////////////////// template < diff --git a/include/cutlass/arch/wmma_sm70.h b/include/cutlass/arch/wmma_sm70.h index 99d81487..bec61172 100644 --- a/include/cutlass/arch/wmma_sm70.h +++ b/include/cutlass/arch/wmma_sm70.h @@ -123,7 +123,7 @@ struct Wmma< nvcuda::wmma::mma_sync(D, A, B, C); } #else - static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond"); + static_assert(false, "wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond"); #endif }; diff --git a/include/cutlass/arch/wmma_sm72.h b/include/cutlass/arch/wmma_sm72.h index 3c488c76..95c639e6 100644 --- a/include/cutlass/arch/wmma_sm72.h +++ b/include/cutlass/arch/wmma_sm72.h @@ -117,7 +117,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); #endif }; @@ -197,7 +197,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); #endif }; diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index d49e8ca8..10e9d916 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -115,7 +115,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); #endif }; @@ -194,7 +194,7 @@ struct Wmma< } #else - static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); #endif }; diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index 6a61379c..756890bb 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -118,7 +118,7 @@ struct Array { // result[0] = xxx; // ``` // - // Will leads to compiler warning on use of unintialized member variable. Although we know + // Will leads to compiler warning on use of uninitialized member variable. Although we know // this read of uninitialized member variable is harmeless. #if defined(__clang__) diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index 77e4c5dc..932d1abd 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -1056,7 +1056,7 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and /// multistage pipeline. template < typename ElementA, @@ -1184,7 +1184,7 @@ struct DefaultConv2dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and // multistage pipeline with interleaved layout. template < typename ElementA, diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h index 107a1be6..85b142a0 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -215,7 +215,7 @@ struct DefaultConv2dFpropFusion < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and /// multistage pipeline. template < typename ElementA, diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h index 024fb820..513de059 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -217,7 +217,7 @@ struct DefaultConv3dFpropFusion < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv3dFprop specialzation for Optimzed IteratorAlgorithm and +/// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and /// multistage pipeline. template < typename ElementA, diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index 98e77893..a8af62be 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief Interface betweeen a CUTLASS device-wide operator and CUDA. + \brief Interface between a CUTLASS device-wide operator and CUDA. */ #pragma once @@ -392,7 +392,7 @@ protected: /** * Fills a buffer in Global Memory with a byte sequence copied from host memory. - * This function can be overriden to dispatch to the appropriate cuMemsetD*Async API + * This function can be overridden to dispatch to the appropriate cuMemsetD*Async API */ virtual Status memsetDeviceImpl( void* destination, ///< Device memory pointer to be filled diff --git a/include/cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl index e127c38c..df39a2b1 100755 --- a/include/cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl @@ -271,7 +271,7 @@ struct CollectiveBuilder< // Construct TileShape for SFB load from GMEM to SMEM. // It is required to keep consistency with BlockScaled granularity defined in Sm1xxBlkScaledConfig. - // So that TileShape for scaling factor needs to be defined as a mutliple of Blk_MN. + // So that TileShape for scaling factor needs to be defined as a multiple of Blk_MN. using TileShapeSf_MNK = decltype(make_shape(ceil_div(size<0>(TileShape_MNK{}), Blk_MN{}) * Blk_MN{}, ceil_div(size<1>(TileShape_MNK{}), Blk_MN{}) * Blk_MN{}, shape<2>(TileShape_MNK{}))); diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp index 6442eb3b..03163121 100755 --- a/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp @@ -153,13 +153,13 @@ struct CollectiveMma< // Asymmetric buffering // Tensor A/B could have different buffering, with TILEK, and STAGEs. // It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's - // pipeline keep same steps when procude / consume data. + // pipeline keep same steps when produce / consume data. // Currently, AsymmetricKRatio = {1, 2} is the only support. static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 2 : 1; // Construct TileShape for SFB load from GMEM to SMEM. // It is required to keep consistency with BlockScaled granularity defined in Sm1xxBlkScaledConfig. - // So that TileShape for scaling factor needs to be defined as a mutliple of Blk_MN. + // So that TileShape for scaling factor needs to be defined as a multiple of Blk_MN. using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; using TileShapeSF = decltype(make_shape(ceil_div(size<0>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}, ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}, diff --git a/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp index 3316308b..7eec27bc 100644 --- a/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp @@ -136,7 +136,7 @@ struct CollectiveMma< // Asymmetric buffering // Tensor A/B could have different buffering, with TILEK, and STAGEs. // It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's - // pipeline keep same steps when procude / consume data. + // pipeline keep same steps when produce / consume data. static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 2 : 1; using TileShapeB = decltype(make_shape(size<0>(TileShape{}), diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp index 97758404..2e3e394d 100644 --- a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp +++ b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -100,7 +100,7 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; // Follow the change in TestSmall: TileShape switch to CtaShape - // For sm80 arch, CtaShape should euqal to TileShape + // For sm80 arch, CtaShape should equal to TileShape using CtaShape_MNK = TileShape; static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); diff --git a/include/cutlass/gemm/device/ell_gemm.h b/include/cutlass/gemm/device/ell_gemm.h index 4261496b..097debf5 100644 --- a/include/cutlass/gemm/device/ell_gemm.h +++ b/include/cutlass/gemm/device/ell_gemm.h @@ -99,7 +99,7 @@ namespace device { Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format: a_rows - Rows in the sparse matrix. - a_cols - Colums in the sparse matrix. + a_cols - Columns in the sparse matrix. BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in consecutive blocks, whose size is (a_rows * a_ell_num_columns) ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is @@ -715,7 +715,7 @@ public: /// Constructs the GEMM. EllGemm() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static UnderlyingArguments to_underlying_arguments(Arguments const &args) { return UnderlyingArguments( {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index 7c36f6a7..f4ea4ebe 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -696,7 +696,7 @@ public: /// Constructs the GEMM. Gemm() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static UnderlyingArguments to_underlying_arguments(Arguments const &args) { return UnderlyingArguments( {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index c59ea0d5..ab5ed26b 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -653,7 +653,7 @@ public: /// Constructs the GEMM. GemmArray() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static UnderlyingArguments to_underlying_arguments(Arguments const &args) { GemmCoord problem_size{ diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index 45a471ce..4a5b4105 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -626,7 +626,7 @@ public: /// Constructs the GEMM. GemmBatched() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static UnderlyingArguments to_underlying_arguments(Arguments const &args) { return UnderlyingArguments( {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 35965012..b0403230 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -645,7 +645,7 @@ public: /// Constructs the GEMM. GemmComplex() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static UnderlyingArguments to_underlying_arguments(Arguments const &args) { return UnderlyingArguments( {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, diff --git a/include/cutlass/gemm/device/gemm_splitk_parallel.h b/include/cutlass/gemm/device/gemm_splitk_parallel.h index e0599810..1cf506f5 100644 --- a/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -561,7 +561,7 @@ public: /// Constructs the GEMM. GemmSplitKParallel() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static UnderlyingArguments to_underlying_arguments(Arguments const &args) { return UnderlyingArguments( {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, diff --git a/include/cutlass/gemm/device/gemm_universal.h b/include/cutlass/gemm/device/gemm_universal.h index 5da6a367..c2c76eb8 100644 --- a/include/cutlass/gemm/device/gemm_universal.h +++ b/include/cutlass/gemm/device/gemm_universal.h @@ -367,7 +367,7 @@ public: /// Constructs the GEMM. GemmUniversal() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static Arguments to_underlying_arguments(Arguments const &args) { return args.transposed_problem(); } diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 3317445e..3508de00 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -693,7 +693,7 @@ public: /// Constructs the GEMM. GemmUniversalAdapter() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static Arguments to_underlying_arguments(Arguments const &args) { if (kInternalTranspose) { return args.transposed_problem(); diff --git a/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h b/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h index 7de048bb..84d148d8 100644 --- a/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h +++ b/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h @@ -311,7 +311,7 @@ public: /// Constructs the GEMM. GemmUniversalStreamkWithBroadcast() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static Arguments to_underlying_arguments(Arguments const &args) { return args.transposed_problem(); } diff --git a/include/cutlass/gemm/device/gemm_universal_with_absmax.h b/include/cutlass/gemm/device/gemm_universal_with_absmax.h index 2459d3a1..d2172d63 100644 --- a/include/cutlass/gemm/device/gemm_universal_with_absmax.h +++ b/include/cutlass/gemm/device/gemm_universal_with_absmax.h @@ -329,7 +329,7 @@ public: /// Constructs the GEMM. GemmUniversalWithAbsMax() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static Arguments to_underlying_arguments(Arguments const &args) { return args.transposed_problem(); } diff --git a/include/cutlass/gemm/device/gemm_universal_with_broadcast.h b/include/cutlass/gemm/device/gemm_universal_with_broadcast.h index 70b18347..f04bf8d5 100644 --- a/include/cutlass/gemm/device/gemm_universal_with_broadcast.h +++ b/include/cutlass/gemm/device/gemm_universal_with_broadcast.h @@ -311,7 +311,7 @@ public: /// Constructs the GEMM. GemmUniversalWithBroadcast() { } - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static Arguments to_underlying_arguments(Arguments const &args) { return args.transposed_problem(); } diff --git a/include/cutlass/gemm/device/gemm_with_k_reduction.h b/include/cutlass/gemm/device/gemm_with_k_reduction.h index 2f64d04b..5bde1161 100644 --- a/include/cutlass/gemm/device/gemm_with_k_reduction.h +++ b/include/cutlass/gemm/device/gemm_with_k_reduction.h @@ -340,7 +340,7 @@ public: /// Constructs the GEMM. GemmWithKReduction() = default; - /// Helper to construct a transposed equivalent for the underying GEMM operator + /// Helper to construct a transposed equivalent for the underlying GEMM operator static Arguments to_underlying_arguments(Arguments const &args) { return args.transposed_problem(); } diff --git a/include/cutlass/gemm/device/rank_2k.h b/include/cutlass/gemm/device/rank_2k.h index 8e7f436d..293ca06a 100644 --- a/include/cutlass/gemm/device/rank_2k.h +++ b/include/cutlass/gemm/device/rank_2k.h @@ -473,7 +473,7 @@ public: /// Constructs the Rank2K. Rank2K() { } - /// Helper to construct a transposed equivalent for the underying Rank2K operator + /// Helper to construct a transposed equivalent for the underlying Rank2K operator static Arguments to_underlying_arguments(Arguments const &args) { return args.transposed_problem(); } diff --git a/include/cutlass/gemm/device/rank_k.h b/include/cutlass/gemm/device/rank_k.h index 665c4e3e..80c420cd 100644 --- a/include/cutlass/gemm/device/rank_k.h +++ b/include/cutlass/gemm/device/rank_k.h @@ -436,7 +436,7 @@ public: /// Constructs the RankK. RankK() { } - /// Helper to construct a transposed equivalent for the underying RankK operator + /// Helper to construct a transposed equivalent for the underlying RankK operator static Arguments to_underlying_arguments(Arguments const &args) { return args; } diff --git a/include/cutlass/gemm/device/symm.h b/include/cutlass/gemm/device/symm.h index 69a76996..538d294f 100755 --- a/include/cutlass/gemm/device/symm.h +++ b/include/cutlass/gemm/device/symm.h @@ -528,7 +528,7 @@ public: /// Constructs the Symm. Symm() { } - /// Helper to construct a transposed equivalent for the underying SYMM operator + /// Helper to construct a transposed equivalent for the underlying SYMM operator static Arguments to_underlying_arguments(Arguments const &args) { return args.transposed_problem_size(); } diff --git a/include/cutlass/gemm/device/trmm.h b/include/cutlass/gemm/device/trmm.h index 2a9ed8e1..46f6473e 100644 --- a/include/cutlass/gemm/device/trmm.h +++ b/include/cutlass/gemm/device/trmm.h @@ -300,7 +300,7 @@ class Trmm { static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? AlignmentA : AlignmentB; static int const kAlignmentC = EpilogueOutputOp::kCount; static bool const kSplitKSerial = SplitKSerial; - // Complex Transform don't appply to B + // Complex Transform don't apply to B static ComplexTransform const kTransformA = TransformA; static ComplexTransform const kTransformB = ComplexTransform::kNone; static ComplexTransform const kTransformAKernel = (SideModeA == SideMode::kRight) ? @@ -651,7 +651,7 @@ class Trmm { static BlasMode const kBlasMode = BlasMode::kSymmetric; - // Complex Transform don't appply to A or B for SYMM + // Complex Transform don't apply to A or B for SYMM static ComplexTransform const TransformA = ComplexTransform::kNone; static ComplexTransform const TransformB = ComplexTransform::kNone; @@ -353,7 +353,7 @@ struct DefaultSymmComplex< Operator, SplitKSerial, BlasMode::kSymmetric> { static BlasMode const kBlasMode = BlasMode::kSymmetric; - // Complex Transform don't appply to A or B for SYMM + // Complex Transform don't apply to A or B for SYMM static ComplexTransform const TransformA = ComplexTransform::kNone; static ComplexTransform const TransformB = ComplexTransform::kNone; diff --git a/include/cutlass/gemm/kernel/gemv_batched_strided.h b/include/cutlass/gemm/kernel/gemv_batched_strided.h index 3b22c110..42b12c3e 100755 --- a/include/cutlass/gemm/kernel/gemv_batched_strided.h +++ b/include/cutlass/gemm/kernel/gemv_batched_strided.h @@ -70,7 +70,7 @@ namespace detail using CDType = typename FragmentCD::value_type; static_assert(FragmentCD::kElements == FragmentAccumulator::kElements, - "Mistmatch in fragment sizes."); + "Mismatch in fragment sizes."); for (int i = 0; i < FragmentCD::kElements; ++i) { diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index b5538a72..d12241aa 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -52,7 +52,7 @@ namespace cutlass::gemm::kernel::detail { // Therefore, we don't how many tiles there will be for the scheduler to hand out. // Hence, we have a SM90 style static group scheduler that launches the largest grid possible. // If we had access to host-side problem shapes, one could to use it to figure out the grid shape -// and thereafter use CLC query (which can then be linearized and mapped to an approriate tile coord). +// and thereafter use CLC query (which can then be linearized and mapped to an appropriate tile coord). template class PersistentTileSchedulerSm100Group { diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp index ca853cd3..8d6e286b 100644 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp @@ -728,7 +728,7 @@ private: auto cluster_start_linear_id = sm_count * wave_idx + cluster_idx; // Determine the offset of this CTA in the preferred cluster shape. - // This calculation aims to accomodate both cases in which this CTA is part of a preferred cluster + // This calculation aims to accommodate both cases in which this CTA is part of a preferred cluster // and those in which it is part of a fallback cluster. // // The calculation is performed by computing the starting M and N index of the preferred cluster that diff --git a/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp b/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp index 610dfc6e..0f074309 100644 --- a/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp +++ b/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp @@ -120,7 +120,7 @@ public: // Tensor A/B could have different buffering, with number of KBLOCK, aka TILEK, // and STAGEs. It let AsymmetricKRatio, equals KBLOCK_A / KBLOCK_B, to control // the balance of A/B loading, make sure A/B's pipeline keep same cadence - // when procude / consume data. + // when produce / consume data. // Currently, AsymmetricKRatio = {1, 2} is the only support. static constexpr bool isAsymmetric = DispatchPolicy::Schedule::isAsymmetric; static constexpr uint32_t AsymmetricKRatio = isAsymmetric ? 2 : 1; diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 9a89bf21..96037b12 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -409,7 +409,7 @@ struct PersistentTileSchedulerSm90StreamKParams { FastDivmodU64 divmod_clusters_mnl_{}; // We divide up the number of stream-K tiles amongst G groups of stream-K units. - // The stream-K units within a group collaborate to comptue over the `sk_tiles / G` + // The stream-K units within a group collaborate to compute over the `sk_tiles / G` // tiles assigned to that group. Non-unit group sizes can help to preserve L2 locality of // partial chunks computed by stream-K units -- units 0 in each group will compute identical K extents // of tiles that would be assigned in the same wave according to the rasterization order of the @@ -932,7 +932,7 @@ struct PersistentTileSchedulerSm90StreamKParams { } } - // Given decomposition mode output from heuristic, set all feilds of params. + // Given decomposition mode output from heuristic, set all fields of params. void set_params( DecompositionMode heuristic_mode, uint32_t groups, @@ -954,7 +954,7 @@ struct PersistentTileSchedulerSm90StreamKParams { , uint32_t ktile_start_alignment_count ) { // The highest priority when customers set as splitk mode, may set - // with a adpated splits value rather than the original splits + // with a adapted splits value rather than the original splits // even it does not make sense if (splits > 1 && heuristic_mode == DecompositionMode::SplitK) { set_params_basic( diff --git a/include/cutlass/gemm/threadblock/default_ell_mma.h b/include/cutlass/gemm/threadblock/default_ell_mma.h index e27c582e..0ae82f32 100644 --- a/include/cutlass/gemm/threadblock/default_ell_mma.h +++ b/include/cutlass/gemm/threadblock/default_ell_mma.h @@ -94,7 +94,7 @@ template < typename InstructionShape_, /// Number of stages used in the pipelined mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. @@ -364,7 +364,7 @@ template < typename InstructionShape, /// Number of stages used in the multistage mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator > struct DefaultEllMma struct DefaultEllMma struct DefaultSparseMma struct DefaultTrmm struct DefaultTrmm struct DefaultTrmm struct DefaultTrmm 2, as a single tile may be covered by four SK-blocks, // e.g.:[partial-block | block | block | partial-block] ). With three or - // less peers, the two non-finishing SK-blocks are not expexted to contend. + // less peers, the two non-finishing SK-blocks are not expected to contend. if ((kReductionStrategy == kMixed) && (sk_waves < sm_occupancy) && (sk_blocks > 2 * sk_tiles)) diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_complex_tensor_op.h index baaced7c..e4b7cf03 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op.h @@ -782,7 +782,7 @@ public: for (int n = 0; n < MmaIterations::kColumn; ++n) { // negate OperandB to accumulate -(a.imag()*b.imag()) - // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements + // negating OperandB emits less instructions than negating OperandA as OperandB has less elements negate negate_op; // Real-valued accumulator part diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h b/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h index e84ae06c..fd90ab8c 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h @@ -598,7 +598,7 @@ public: for (int n = 0; n < MmaIterations::kColumn; ++n) { // negate OperandB to accumulate -(a.imag()*b.imag()) - // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements + // negating OperandB emits less instructions than negating OperandA as OperandB has less elements negate negate_op; // Real-valued accumulator part diff --git a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h index 4e16ff89..b0757505 100644 --- a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h @@ -427,7 +427,7 @@ public: using TransformedFragmentA = Array; - /// Underlying arch::Mma instruction operand fragement for matrix A + /// Underlying arch::Mma instruction operand fragment for matrix A using MmaOperandA = typename ArchMmaOperator::FragmentA; /// Iterates over the B operand in Shared Memory @@ -443,7 +443,7 @@ public: using TransformedFragmentB = Array; - /// Underlying arch::Mma instruction operand fragement for matrix B + /// Underlying arch::Mma instruction operand fragment for matrix B using MmaOperandB = typename ArchMmaOperator::FragmentB; /// Iterates over the C operand in memory @@ -454,7 +454,7 @@ public: /// Storage for C tile using FragmentC = typename IteratorC::Fragment; - /// Underlying arch::Mma instruction operand fragement for matrix C + /// Underlying arch::Mma instruction operand fragment for matrix C using MmaOperandC = typename ArchMmaOperator::FragmentC; /// Number of mma operations performed diff --git a/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h index 81668b44..902a3d10 100644 --- a/include/cutlass/gemm/warp/mma_sparse_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h @@ -117,7 +117,7 @@ public: /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) using Policy = Policy_; - /// Equivalant base dense mma + /// Equivalent base dense mma using Base = MmaTensorOp; diff --git a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h index 1489694e..c70bc581 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h @@ -33,7 +33,7 @@ \brief This defines a "fragment" iterator for visiting the fragments of a warp tile that participate in one warp-level mma operation. - Typically, this is used to access the accumulator tile/fragement of a warp-level mma operation. + Typically, this is used to access the accumulator tile/fragment of a warp-level mma operation. The accumulator tile is then partitioned into smaller tiles/fragments that can be fed into next warp-level mma operation. diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h index 6446b7bd..f37c5c14 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h @@ -62,7 +62,7 @@ namespace warp { /// Tile access iterator -/// Each iteration acess in the tile is +/// Each iteration access in the tile is /// used as multiplicand for one /// warp-level matrix multiplication template < diff --git a/include/cutlass/half.h b/include/cutlass/half.h index f5fb90d2..118a80d7 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -68,7 +68,7 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// -// Optionally target F16C extentions to accelerate half-precision conversion. +// Optionally target F16C extensions to accelerate half-precision conversion. #if !defined(__CUDA_ARCH__) && (CUTLASS_ENABLE_F16C) #if defined(_MSC_VER) diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index c24e2bab..5d7c685f 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -90,7 +90,7 @@ struct KernelHardwareInfo { int max_active_clusters = 0; #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) ClusterLauncher::LaunchConfig cluster_launch_config = ClusterLauncher::make_cluster_launch_config( - cluster_dims /* minumum grid dim */, cluster_dims, {threads_per_block, 1, 1}); + cluster_dims /* minimum grid dim */, cluster_dims, {threads_per_block, 1, 1}); // Given the kernel function and launch configuration, return the maximum number of clusters that could co-exist on the target device. cudaError_t result = cudaOccupancyMaxActiveClusters(&max_active_clusters, kernel_ptr, &cluster_launch_config.launch_config); if (result != cudaSuccess) { diff --git a/include/cutlass/matrix.h b/include/cutlass/matrix.h index b46cbfec..78e15859 100644 --- a/include/cutlass/matrix.h +++ b/include/cutlass/matrix.h @@ -101,7 +101,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 1-by-2 matrix from scalar elements + /// Constructs a 1-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1 @@ -599,7 +599,7 @@ template using Matrix1x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix1x2 make_Matrix1x2( Element _0_0, Element _0_1 @@ -658,7 +658,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 1-by-3 matrix from scalar elements + /// Constructs a 1-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2 @@ -1226,7 +1226,7 @@ template using Matrix1x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix1x3 make_Matrix1x3( Element _0_0, Element _0_1, Element _0_2 @@ -1285,7 +1285,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 1-by-4 matrix from scalar elements + /// Constructs a 1-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3 @@ -1905,7 +1905,7 @@ template using Matrix1x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix1x4 make_Matrix1x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3 @@ -1964,7 +1964,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-1 matrix from scalar elements + /// Constructs a 2-by-1 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, @@ -2471,7 +2471,7 @@ template using Matrix2x1 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x1 make_Matrix2x1( Element _0_0, @@ -2532,7 +2532,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-2 matrix from scalar elements + /// Constructs a 2-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, @@ -2543,7 +2543,7 @@ struct Matrix { data[2] = _1_0; data[3] = _1_1; } - /// Constucts a 2-by-2 matrix from row vectors + /// Constructs a 2-by-2 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -3258,7 +3258,7 @@ template using Matrix2x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x2 make_Matrix2x2( Element _0_0, Element _0_1, @@ -3319,7 +3319,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-3 matrix from scalar elements + /// Constructs a 2-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, @@ -3330,7 +3330,7 @@ struct Matrix { data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; } - /// Constucts a 2-by-3 matrix from row vectors + /// Constructs a 2-by-3 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -4128,7 +4128,7 @@ template using Matrix2x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x3 make_Matrix2x3( Element _0_0, Element _0_1, Element _0_2, @@ -4189,7 +4189,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 2-by-4 matrix from scalar elements + /// Constructs a 2-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -4200,7 +4200,7 @@ struct Matrix { data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; } - /// Constucts a 2-by-4 matrix from row vectors + /// Constructs a 2-by-4 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -5134,7 +5134,7 @@ template using Matrix2x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix2x4 make_Matrix2x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -5195,7 +5195,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-1 matrix from scalar elements + /// Constructs a 3-by-1 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, @@ -5780,7 +5780,7 @@ template using Matrix3x1 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x1 make_Matrix3x1( Element _0_0, @@ -5843,7 +5843,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-2 matrix from scalar elements + /// Constructs a 3-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, @@ -5856,7 +5856,7 @@ struct Matrix { data[4] = _2_0; data[5] = _2_1; } - /// Constucts a 3-by-2 matrix from row vectors + /// Constructs a 3-by-2 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -6665,7 +6665,7 @@ template using Matrix3x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x2 make_Matrix3x2( Element _0_0, Element _0_1, @@ -6728,7 +6728,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-3 matrix from scalar elements + /// Constructs a 3-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, @@ -6741,7 +6741,7 @@ struct Matrix { data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; } - /// Constucts a 3-by-3 matrix from row vectors + /// Constructs a 3-by-3 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -7896,7 +7896,7 @@ template using Matrix3x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x3 make_Matrix3x3( Element _0_0, Element _0_1, Element _0_2, @@ -7959,7 +7959,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 3-by-4 matrix from scalar elements + /// Constructs a 3-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -7972,7 +7972,7 @@ struct Matrix { data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; } - /// Constucts a 3-by-4 matrix from row vectors + /// Constructs a 3-by-4 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -9208,7 +9208,7 @@ template using Matrix3x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix3x4 make_Matrix3x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -9271,7 +9271,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-1 matrix from scalar elements + /// Constructs a 4-by-1 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, @@ -9918,7 +9918,7 @@ template using Matrix4x1 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x1 make_Matrix4x1( Element _0_0, @@ -9983,7 +9983,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-2 matrix from scalar elements + /// Constructs a 4-by-2 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, @@ -9998,7 +9998,7 @@ struct Matrix { data[6] = _3_0; data[7] = _3_1; } - /// Constucts a 4-by-2 matrix from row vectors + /// Constructs a 4-by-2 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -10958,7 +10958,7 @@ template using Matrix4x2 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x2 make_Matrix4x2( Element _0_0, Element _0_1, @@ -11023,7 +11023,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-3 matrix from scalar elements + /// Constructs a 4-by-3 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, @@ -11038,7 +11038,7 @@ struct Matrix { data[9] = _3_0; data[10] = _3_1; data[11] = _3_2; } - /// Constucts a 4-by-3 matrix from row vectors + /// Constructs a 4-by-3 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -12291,7 +12291,7 @@ template using Matrix4x3 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x3 make_Matrix4x3( Element _0_0, Element _0_1, Element _0_2, @@ -12356,7 +12356,7 @@ struct Matrix { data = rhs.data; } - /// Constucts a 4-by-4 matrix from scalar elements + /// Constructs a 4-by-4 matrix from scalar elements CUTLASS_HOST_DEVICE Matrix( Element _0_0, Element _0_1, Element _0_2, Element _0_3, @@ -12371,7 +12371,7 @@ struct Matrix { data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3; } - /// Constucts a 4-by-4 matrix from row vectors + /// Constructs a 4-by-4 matrix from row vectors CUTLASS_HOST_DEVICE Matrix( Matrix const &row_0, @@ -14096,7 +14096,7 @@ template using Matrix4x4 = Matrix; -/// Free funciton to infer element type from template arguments +/// Free function to infer element type from template arguments template CUTLASS_HOST_DEVICE Matrix4x4 make_Matrix4x4( Element _0_0, Element _0_1, Element _0_2, Element _0_3, diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 886dc9f2..1d5856f9 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -51,7 +51,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Floating-point rounding style similare to Standard Library's formats but supporting +/// Floating-point rounding style similar to Standard Library's formats but supporting /// additional rounding options. enum class FloatRoundStyle { round_indeterminate, ///< rounding mode unknown @@ -6175,7 +6175,7 @@ private: asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(r[ii]) : "r"(src_reg), "r"(prmt_indices[ii])); } - // In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve + // In the absence of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve // the same result as add.s16x2 instruction. // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to diff --git a/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp index 99c5bf70..41bc4786 100644 --- a/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp +++ b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp @@ -289,7 +289,7 @@ public: static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; - static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." ); + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; @@ -510,7 +510,7 @@ public: static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; - static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." ); + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; diff --git a/include/cutlass/transform/pitch_linear_thread_map.h b/include/cutlass/transform/pitch_linear_thread_map.h index 6a8970e8..ef553aab 100644 --- a/include/cutlass/transform/pitch_linear_thread_map.h +++ b/include/cutlass/transform/pitch_linear_thread_map.h @@ -298,7 +298,7 @@ struct PitchLinearWarpRakedThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = layout::PitchLinearShape< Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, Detail::WarpThreadArrangement::kStrided @@ -427,7 +427,7 @@ struct PitchLinearStridedWarpRakedThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = typename BaseThreadMap::Delta; /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space @@ -531,7 +531,7 @@ struct TransposePitchLinearThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = layout::PitchLinearShape; @@ -716,7 +716,7 @@ struct PitchLinearWarpStripedThreadMap { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = layout::PitchLinearShape< Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided @@ -897,7 +897,7 @@ struct TransposePitchLinearThreadMap2DThreadTile { /// Shape of access by each thread using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; - ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) + ///< Delta between accesses (units of elements, concept: PitchLinearShape) using Delta = layout::PitchLinearShape; diff --git a/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h index 48fb983f..e377bba4 100644 --- a/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h +++ b/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h @@ -76,7 +76,7 @@ namespace threadblock { /// To be efficient, this assumes the iterator will be dereferenced and advanced at least once /// outside any looping structure to minimize integer arithmetic. /// -/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing /// the iterator. /// /// diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index e5c2a5f0..36407098 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -419,7 +419,7 @@ class PredicatedTileAccessIterator, Element>; static bool const transpose = Transpose_; diff --git a/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h b/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h index 3acc31ff..9bf5e858 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h +++ b/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h @@ -76,10 +76,10 @@ namespace threadblock { /// accesses may be performed without updating internal predicates and are efficient in terms of /// live register state and pointer arithmetic instructions. /// -/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once /// outside any looping structure to minimize integer arithmetic. /// -/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing /// the iterator. /// /// diff --git a/include/cutlass/transform/warp/vector_fragment_iterator.h b/include/cutlass/transform/warp/vector_fragment_iterator.h index 707cbcc8..b27b77f9 100644 --- a/include/cutlass/transform/warp/vector_fragment_iterator.h +++ b/include/cutlass/transform/warp/vector_fragment_iterator.h @@ -34,7 +34,7 @@ \brief This defines a "fragment" iterator for visiting the fragments of a warp vector that participate in one warp-level mma operation. - Typically, this is used to access the scale/bias fragement of a warp-level mma operation. + Typically, this is used to access the scale/bias fragment of a warp-level mma operation. The scale/bias vector is then partitioned into smaller fragments that can be fed into next warp-level mma operation. diff --git a/media/docs/pythonDSL/cute_dsl.rst b/media/docs/pythonDSL/cute_dsl.rst index 108837a0..a7f53587 100644 --- a/media/docs/pythonDSL/cute_dsl.rst +++ b/media/docs/pythonDSL/cute_dsl.rst @@ -12,6 +12,7 @@ CuTe DSL JIT Argument Generation JIT Argument: Layouts JIT Caching + JIT Compilation Options Integration with Frameworks Debugging with the DSL Autotuning with the DSL diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst index e3c9316e..21f0912e 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst @@ -178,7 +178,7 @@ Limitations of Dynamic Control Flow n = 10 # ❌ This loop is dynamic, early-exit isn't allowed. - for i in cutlass.range_dynamic(n): + for i in range(n): if i == 5: break # Early-exit diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst new file mode 100644 index 00000000..d5dd42f8 --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst @@ -0,0 +1,50 @@ +.. _dsl_jit_compilation_options: +.. |DSL| replace:: CuTe DSL + +.. _JIT_Compilation_Options: + +JIT Compilation Options +======================= + +JIT Compilation Options Overview +-------------------------------- + +When compiling a JIT function using |DSL|, you may want to control various aspects of the compilation process, such as optimization level, or debugging flags. |DSL| provides a flexible interface for specifying these compilation options when invoking ``cute.compile``. + +Compilation options allow you to customize how your JIT-compiled functions are built and executed. This can be useful for: + +* Enabling or disabling specific compiler optimizations +* Generating debug information for troubleshooting + +These options can be passed as keyword arguments to ``cute.compile`` or set globally for all JIT compilations. The available options and their effects are described in the following sections, along with usage examples to help you get started. + + +``cute.compile`` Compilation Options +------------------------------------ + +You can provide additional compilation options as a string when calling ``cute.compile``. The |DSL| uses ``argparse`` to parse these options and will raise an error if any invalid options are specified. + +.. list-table:: + :header-rows: 1 + :widths: 20 20 15 25 + + * - **Option** + - **Description** + - **Default** + - **Type** + * - ``opt-level`` + - Optimization level of compilation. The higher the level, the more optimizations are applied. The valid value range is [0, 3]. + - 3 (highest level of optimization) + - int + * - ``enable-device-assertions`` + - Enable device code assertions. + - False + - bool + +You can use the following code to specify compilation options: + +.. code-block:: python + + jit_executor_with_opt_level_2 = cute.compile(add, 1, 2, options="--opt-level 2") + jit_executor_with_opt_level_1 = cute.compile(add, 1, 2, options="--opt-level 1") + jit_executor_with_enable_device_assertions = cute.compile(add, 1, 2, options="--enable-device-assertions") diff --git a/media/docs/pythonDSL/limitations.rst b/media/docs/pythonDSL/limitations.rst index 59c9ad51..7b42c503 100644 --- a/media/docs/pythonDSL/limitations.rst +++ b/media/docs/pythonDSL/limitations.rst @@ -54,6 +54,7 @@ Programming Model - Modifiable during execution of JIT-compiled functions - Only a specific subset of Python types are supported as dynamic values - Primitive types are automatically converted when passed as function arguments: + - ``int`` → ``Int32`` (may be updated to ``Int64`` in future releases) - ``bool`` → ``Bool`` - ``float`` → ``Float32`` (may be updated to ``Float64`` in future releases) @@ -77,7 +78,7 @@ Programming Model # of the runtime value of `i` xs.append(Float32(3.0)) - for i in range_dynamic(10): + for i in range(10): # This only append one element to the list at compile-time # as loop doesn't unroll at compile-time xs.append(Float32(1.0)) @@ -142,16 +143,29 @@ Programming Model @cute.jit def foo(): a = Int32(1) - for i in range_dynamic(10): + for i in range(10): a = Float32(2) # Changing type inside loop-body is not allowed in the DSL + **Built-in Operators** The DSL transforms built-in operators like ``and``, ``or``, ``max``, ``min``, etc. into MLIR operations. They also follow the same constraints of dependent types. For instance, ``a and b`` requires ``a`` and ``b`` to be of the same type. - Comparison like ``==`` on Sequence of dynamic values is known to not produce - expected result at runtime. + +**Special Variables** + The DSL treats ``_`` as a special variable that it's value is meant to be ignored. + It is not allowed to read ``_`` in the DSL. + + Example illustrating functionality in Python that is not supported in the DSL: + + .. code:: python + + @cute.jit + def foo(): + _ = 1 + print(_) # This is not allowed in the DSL + **Object Oriented Programming** The DSL is implemented on top of Python and supports Python's object-oriented programming (OOP) features @@ -179,7 +193,7 @@ Programming Model @cute.jit def foo(a: Int32, res: cute.Tensor): foo = Foo(a) - for i in cutlass.range_dynamic(10): + for i in range(10): foo.set_a(i) # This fails to compile because `a` is assigned a local value defined within the for-loop body diff --git a/python/CuTeDSL/base_dsl/ast_helpers.py b/python/CuTeDSL/base_dsl/ast_helpers.py index 756d151f..b857e40e 100644 --- a/python/CuTeDSL/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/base_dsl/ast_helpers.py @@ -300,6 +300,21 @@ def if_executor( class range: + """ + A range-like object for dynamic loop iteration in the DSL. + + This class provides a range interface similar to Python's built-in range, + but is designed to be preprocessed into constructs for dynamic + loop execution. + + The class supports both single-argument (stop) and three-argument + (start, stop, step) constructors with additional parameters for loop + optimization: + + - unroll: Number of iterations to unroll (0 or 1 = no unrolling) + - unroll_full: Whether to fully unroll the loop + - pipelining: Compiler generated pipeline configuration + """ @overload def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None): pass @@ -460,7 +475,31 @@ def range_value_check(*args): Ensure all `range_constexpr` bounds are compile-time constants (Python ints). """ try: - return tuple(arg.__index__() for arg in args) + args = tuple(arg.__index__() for arg in args) + + # Compute range size and warn if it's too large + start = 0 + end = 0 + step = 1 + if len(args) == 1: + end = args[0] + elif len(args) == 2: + start = args[0] + end = args[1] + elif len(args) == 3: + start = args[0] + end = args[1] + step = args[2] + + range_length = (abs(end - start) - 1) // abs(step) + 1 + if range_length >= 64: + warnings.warn( + f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.", + category=UserWarning, + stacklevel=2, + ) + + return (start, end, step) except: raise DSLRuntimeError( "`range_constexpr` requires constexpr (compile-time constant) for all arguments.", @@ -477,8 +516,8 @@ def range_perf_warning(filename, lineno, *args): if not has_dynamic_expr: warnings.warn_explicit( ( - "The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression." - "If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`." + "This loop is no longer unrolled and may cause performance regression. " + "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants." ), category=UserWarning, filename=filename, diff --git a/python/CuTeDSL/base_dsl/ast_preprocessor.py b/python/CuTeDSL/base_dsl/ast_preprocessor.py index ea73831c..bffbc7f2 100644 --- a/python/CuTeDSL/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/base_dsl/ast_preprocessor.py @@ -102,6 +102,8 @@ class ScopeManager: return cls([]) def add_to_scope(self, name: str) -> None: + if name == "_": + return self.scopes[-1].add(name) def get_active_symbols(self) -> List[Set[str]]: @@ -361,13 +363,13 @@ class DSLPreprocessor(ast.NodeTransformer): isinstance(func, ast.Name) and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS ): - return func.id, True + return func.id, True, len(iter_node.keywords) != 0 if ( isinstance(func, ast.Attribute) and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS ): - return func.attr, False - return None, None + return func.attr, False, len(iter_node.keywords) != 0 + return None, None, None def transform(self, original_function, exec_globals): """ @@ -378,6 +380,7 @@ class DSLPreprocessor(ast.NodeTransformer): transformed_tree = self.transform_function( original_function.__name__, original_function ) + self.function_globals = None unified_tree = ast.Module(body=transformed_tree, type_ignores=[]) unified_tree = ast.fix_missing_locations(unified_tree) @@ -731,7 +734,7 @@ class DSLPreprocessor(ast.NodeTransformer): self.scope_manager.add_to_scope(node.target.id) # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop. - range_kind, is_builtin_range = self._get_range_kind(node.iter) + range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter) if range_kind == "range_constexpr" or range_kind == None: self.generic_visit(node) if range_kind == "range_constexpr": @@ -752,7 +755,7 @@ class DSLPreprocessor(ast.NodeTransformer): warnings.simplefilter("default", DeprecationWarning) # reset filter warning_call = None - if range_kind == "range" and is_builtin_range: + if range_kind == "range" and is_builtin_range and not has_keyword: # Warn about possible performance regression due to behavior change warning_call = ast.Expr( ast.Call( @@ -1109,6 +1112,12 @@ class DSLPreprocessor(ast.NodeTransformer): self.generic_visit(node) return node + def visit_Name(self, node): + self.generic_visit(node) + if node.id == "_" and isinstance(node.ctx, ast.Load): + raise DSLAstPreprocessorError("Read '_' is not allowed") + return node + def check_decorator(self, node: ast.AST) -> bool: """ Check if the function has the correct decorator for preprocessing. diff --git a/python/CuTeDSL/base_dsl/compiler.py b/python/CuTeDSL/base_dsl/compiler.py index 2e5b75cd..8776c91b 100644 --- a/python/CuTeDSL/base_dsl/compiler.py +++ b/python/CuTeDSL/base_dsl/compiler.py @@ -19,7 +19,9 @@ from typing import Sequence, Optional, Tuple import os import sys import inspect +import argparse from .common import DSLRuntimeError +from .utils.logger import log _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) sys.path.append(_SCRIPT_PATH) @@ -182,7 +184,67 @@ class Compiler: return self.jit(module, opt_level, shared_libs) +class CompileOptions: + def __init__(self, options: str = ""): + """ + This class encapsulates all compilation options relevant to function compilation. + It provides a convenient way to manage and pass compilation options, + particularly for controlling compilation settings. + By centralizing these options, it ensures consistent and flexible configuration of + compilation parameters such as optimization level, debugging control, etc. + + :param options: The options for the function. Will be parsed by argparse. + :type options: str + """ + if not isinstance(options, str): + raise DSLRuntimeError( + f"Invalid compilation `options`: {options}, it should be a string" + ) + self._parser = argparse.ArgumentParser() + self._parser.add_argument("--opt-level", nargs="?", type=int, default=3) + self._parser.add_argument( + "--enable-device-assertions", action="store_true", default=False + ) + try: + self._options = self._parser.parse_args(options.split()) + except SystemExit as e: + # catch argparse error and raise as DSLRuntimeError + raise DSLRuntimeError( + f"Invalid compile options: '{options}'. Please check the option values and format." + ) + log().info("`cute.compile` CompileOptions: options=" + options) + + def to_str(self): + """ + Generate a string representation of all compilation options + which will be used in pipeline options. + """ + option_strings = [] + for key, value in vars(self._options).items(): + hyphen_key = key.replace("_", "-") + if isinstance(value, bool): + formatted_value = "true" if value else "false" + else: + formatted_value = str(value) + option_strings.append(f"{hyphen_key}={formatted_value}") + + return " ".join(option_strings) + + def compile(func, *args, **kwargs): + """ + This function is used to compile a `cute.jit` decorated function. + It will process the compile options and input parameters, do explicit compilation and return the jit executor. + + :param func: The function to compile. It can be a regular function, a method or a class instance. + :param args: The arguments to pass to the function. + :param kwargs: The keyword arguments to pass to the function. It can contain `options` like + `opt_level` to control the compilation flags. + + :return: The jit executor. + + :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable. + """ if func is None: raise DSLRuntimeError("Function is not set or invalid.") @@ -217,5 +279,8 @@ def compile(func, *args, **kwargs): if not hasattr(func, "_dsl_object"): raise DSLRuntimeError("Function is not decorated with jit decorator.") + # process compile options, extract the options and remove them from the kwargs + options = kwargs.pop("options", "") + func._dsl_object.compile_options = CompileOptions(options) fcn_ptr = func._dsl_object._preprocess_and_execute(func) return func._dsl_object._func(fcn_ptr, *args, **kwargs) diff --git a/python/CuTeDSL/base_dsl/dsl.py b/python/CuTeDSL/base_dsl/dsl.py index db3f0392..c6ece00c 100644 --- a/python/CuTeDSL/base_dsl/dsl.py +++ b/python/CuTeDSL/base_dsl/dsl.py @@ -38,6 +38,7 @@ import warnings from . import typing as t from .env_manager import EnvironmentVarManager +from .compiler import CompileOptions # ============================================================================= # CUDA Python @@ -232,6 +233,50 @@ def new_from_mlir_values(obj, values): return obj +class DSLCallable: + """ + Wrapper class for a callable object used within the DSL. + + DSLCallable is designed to wrap a function and provide additional + introspection utilities such as retrieving the argument specification + and signature. It ensures that the wrapped function can only be called + once, after which the reference to the function is cleared to prevent + further invocations. This is useful in scenarios where a function should + only be executed a single time within the DSL's execution model. + + Attributes: + func (callable): The function to be wrapped and managed. + + Methods: + __call__(*args, **kwargs): Calls the wrapped function and clears it. + get_arg_spec(): Returns the argument specification of the function. + get_signature(): Returns the signature of the function. + """ + + def __init__(self, func): + self.func = func + + def __call__(self, *args, **kwargs): + ret = self.__func__(*args, **kwargs) + self.func = None + return ret + + @property + def __func__(self): + assert self.func is not None, "DSLCallable is already called" + return self.func + + @property + def __name__(self): + return self.__func__.__name__ + + def get_arg_spec(self): + return inspect.getfullargspec(self.__func__) + + def get_signature(self): + return inspect.signature(self.__func__) + + class BaseDSL: gpu_module = None @@ -306,6 +351,8 @@ class BaseDSL: self.kernel_symbols = [] # used to generate unique name for gpu.launch self.launch_inner_count = 0 + # initialize default compile options + self.compile_options = CompileOptions() if preprocess: self.preprocessor = DSLPreprocessor() @@ -392,26 +439,24 @@ class BaseDSL: if hasattr(func, "_transformed_ast"): # If the function ptr is already materialized, use the existing one func._dsl_object.frame = func._decorator_frame - if func._transformed_ast is None: func._transformed_ast = func._dsl_object.run_preprocessor(func) if func._transformed_ast is None: - del func._decorator_frame del func._transformed_ast + func._dsl_object.frame = None return func - fcn_ptr = func._dsl_object.get_function_ptr(func, func._transformed_ast) + fcn_ptr = func._dsl_object.get_function_ptr(func) # If the function is decorated, de-decorate it fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__) - return fcn_ptr + func._dsl_object.frame = None + return DSLCallable(fcn_ptr) return func - def jit_runner(self, frame, executor, *dargs, **dkwargs): + def jit_runner(self, executor, frame, *dargs, **dkwargs): """ Decorator to mark a function for JIT compilation. """ - # Set the frame, that can be used AST preprocessor - self.frame = frame log().info("jit_runner") def jit_runner_decorator(func): @@ -444,7 +489,7 @@ class BaseDSL: frame = inspect.currentframe().f_back # Instantiate the DSL Class main_dsl = cls._get_dsl() - return main_dsl.jit_runner(frame, main_dsl._func, *dargs, **dkwargs) + return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs) @classmethod def kernel(cls, *dargs, **dkwargs): @@ -454,7 +499,7 @@ class BaseDSL: frame = inspect.currentframe().f_back # Instantiate the DSL Class main_dsl = cls._get_dsl() - return main_dsl.jit_runner(frame, main_dsl._kernel_helper, *dargs, **dkwargs) + return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs) @abstractmethod def _kernel_helper(self, func, *args, **kwargs): @@ -627,6 +672,12 @@ class BaseDSL: pass @abstractmethod + def _get_module_globals(self): + """ + Get the module's globals. + """ + pass + def _get_globals(self): """ Combines global and local variables from the current context and the @@ -639,7 +690,11 @@ class BaseDSL: AST preprocessor generates a new python code, so the resulting globals dictionary is used to execute the python code. """ - pass + all_globals = self._get_module_globals().copy() + if self.frame: + all_globals.update(self.frame.f_globals) + all_globals.update(self.frame.f_locals) + return all_globals def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: return isinstance( @@ -881,20 +936,15 @@ class BaseDSL: Get python location information and generate MLIR location """ - frame = self.frame - if frame is None: - print("Frame is None") + if self.frame is None: + log().debug("Frame is None") return None - file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0) + file_loc = ir.Location.file( + self.frame.f_code.co_filename, self.frame.f_lineno, 0 + ) - def print_all_frames(): - for i, frame in enumerate(inspect.stack()): - print( - f"Frame {i}: {frame.function} in {frame.filename}, line {frame.lineno}" - ) - - loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc) + loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc) return loc def compile_and_jit(self, module, pipeline, shared_libs, function_name=""): @@ -992,6 +1042,8 @@ class BaseDSL: for attr, value in self.envar.__dict__.items(): if value is not None: s.write(str(value).encode()) + # Add compile options to the hash + s.write(self.compile_options.to_str().encode()) module_hash = self.get_version().copy() module_hash.update(s.getvalue()) module_hash = module_hash.hexdigest() @@ -1145,6 +1197,8 @@ class BaseDSL: self.launch_inner_count = 0 # reset num_kernels to 0 for next compilation. self.num_kernels = 0 + # reset the compile options after the compilation is done. + self.compile_options = CompileOptions() def generate_mlir( self, @@ -1226,9 +1280,11 @@ class BaseDSL: return transformed_ast return None - def get_function_ptr(self, original_function, transformed_ast): + def get_function_ptr(self, original_function): file_name = inspect.getsourcefile(original_function) - code_object = compile(transformed_ast, filename=file_name, mode="exec") + code_object = compile( + original_function._transformed_ast, filename=file_name, mode="exec" + ) return self.preprocessor.exec( original_function.__name__, original_function, @@ -1236,10 +1292,6 @@ class BaseDSL: self._get_globals(), ) - @lru_cache(maxsize=None) - def _get_function_signature(self, func): - return inspect.signature(func) - def _get_function_bound_args(self, sig, func_name, *args, **kwargs): """ Binds provided arguments to a function's signature and applies default values. @@ -1260,12 +1312,11 @@ class BaseDSL: ) return bound_args - def _canonicalize_args(self, *args, **kwargs): + def _canonicalize_args(self, sig, *args, **kwargs): """ Canonicalize the input arguments so that returned args only contain positional arguments and kwargs only contain keyword arguments. """ - sig = self._get_function_signature(self.funcBody) function_name = self.funcBody.__name__ bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) canonicalized_args = bound_args.args @@ -1276,8 +1327,11 @@ class BaseDSL: if not self.funcBody: raise DSLRuntimeError("Function body is not set.") - # Pass the actual function object to _get_function_signature. - sig = self._get_function_signature(self.funcBody) + # Pass the actual function object to inspect.signature to get the signature. + if isinstance(self.funcBody, DSLCallable): + sig = self.funcBody.get_signature() + else: + sig = inspect.signature(self.funcBody) function_name = self.funcBody.__name__ bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) @@ -1292,6 +1346,8 @@ class BaseDSL: f"Missing required argument in `{function_name}`: '{param.name}'" ) + return sig + def _func(self, funcBody, *args, **kwargs): """Decorator for MLIR functions. It cuts the boilerplate code, does the following: @@ -1324,13 +1380,16 @@ class BaseDSL: self.print_warning("Cache is disabled as user wants to compile only.") # Check the number of arguments - self._check_arg_count(*args, **kwargs) + sig = self._check_arg_count(*args, **kwargs) - args_spec = inspect.getfullargspec(funcBody) + if isinstance(funcBody, DSLCallable): + args_spec = funcBody.get_arg_spec() + else: + args_spec = inspect.getfullargspec(funcBody) # Canonicalize the input arguments canonicalized_args, canonicalized_kwargs = self._canonicalize_args( - *args, **kwargs + sig, *args, **kwargs ) # Simple name mangling @@ -1528,7 +1587,10 @@ class BaseDSL: kernelGenHelper = dkwargs.get("kernelGenHelper", None) kernel_name = funcBody.__name__ - args_spec = inspect.getfullargspec(funcBody) + if isinstance(funcBody, DSLCallable): + args_spec = funcBody.get_arg_spec() + else: + args_spec = inspect.getfullargspec(funcBody) self.funcBody = funcBody # Give each kernel a unique name. (The same kernel may be @@ -1568,11 +1630,11 @@ class BaseDSL: ), "kernelGenHelper should be explicitly specified!" # check arguments - self._check_arg_count(*args, **kwargs) + sig = self._check_arg_count(*args, **kwargs) # Canonicalize the input arguments canonicalized_args, canonicalized_kwargs = self._canonicalize_args( - *args, **kwargs + sig, *args, **kwargs ) kernel_operands, kernel_types, kernel_arg_attrs = ( diff --git a/python/CuTeDSL/base_dsl/typing.py b/python/CuTeDSL/base_dsl/typing.py index 3724d325..b46cff6d 100644 --- a/python/CuTeDSL/base_dsl/typing.py +++ b/python/CuTeDSL/base_dsl/typing.py @@ -527,7 +527,16 @@ class IntegerMeta(NumericMeta): return 2**cls.width - 1 def recast_width(cls, width): - return eval(f"Int{width}") + type_map = { + 8: Int8, + 16: Int16, + 32: Int32, + 64: Int64, + 128: Int128, + } + if width not in type_map: + raise TypeError(f"Unsupported width: {width}") + return type_map[width] class FloatMeta(NumericMeta): @@ -603,7 +612,14 @@ class FloatMeta(NumericMeta): return cls._mantissa_width def recast_width(cls, width): - return eval(f"Float{width}") + type_map = { + 16: Float16, + 32: Float32, + 64: Float64, + } + if width not in type_map: + raise TypeError(f"Unsupported width: {width}") + return type_map[width] def _arith_signless_to_int(a, target_type): diff --git a/python/CuTeDSL/cutlass/cute/__init__.py b/python/CuTeDSL/cutlass/cute/__init__.py index 0cfb0e03..9076fcb0 100644 --- a/python/CuTeDSL/cutlass/cute/__init__.py +++ b/python/CuTeDSL/cutlass/cute/__init__.py @@ -118,6 +118,9 @@ from .core import ( make_tiled_copy, make_tiled_copy_S, make_tiled_copy_D, + make_tiled_copy_A, + make_tiled_copy_B, + make_tiled_copy_C, make_tiled_copy_C_atom, basic_copy, basic_copy_if, diff --git a/python/CuTeDSL/cutlass/cute/arch/__init__.py b/python/CuTeDSL/cutlass/cute/arch/__init__.py index 42806994..01198215 100644 --- a/python/CuTeDSL/cutlass/cute/arch/__init__.py +++ b/python/CuTeDSL/cutlass/cute/arch/__init__.py @@ -90,6 +90,7 @@ __all__ = [ # "alloc_smem", "get_dyn_smem", + "get_dyn_smem_size", # # tmem.py # diff --git a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py index f247cf60..80a4c1d0 100644 --- a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +++ b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -26,7 +26,19 @@ from cutlass._mlir.dialects.nvvm import ( RoundingModeKind, ) -from ..typing import Int, Boolean, Int32, Float32, Numeric, as_numeric +from ..typing import ( + Int, + Boolean, + Int16, + Uint16, + Int32, + Uint32, + Int64, + Float32, + BFloat16, + Numeric, + as_numeric, +) WARP_SIZE = 32 FULL_MASK = 0xFFFFFFFF @@ -190,19 +202,97 @@ def shuffle_sync_op( """ if not isinstance(value, Numeric): value = as_numeric(value) - return type(value)( - nvvm.shfl_sync( - type(value).mlir_type, + if value.width > 64: + raise ValueError("shuffle_sync only supports values up to 64 bits") + + orig_type = type(value) + if value.width < 32: + if value.dtype.is_float: + value = value.to(Float32) + else: + if value.signed: + value = value.to(Int32) + else: + value = value.to(Uint32) + return orig_type( + nvvm.shfl_sync( + type(value).mlir_type, + Int32(mask).ir_value(loc=loc, ip=ip), + value.ir_value(loc=loc, ip=ip), + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + ) + elif value.width == 32: + return orig_type( + nvvm.shfl_sync( + type(value).mlir_type, + Int32(mask).ir_value(loc=loc, ip=ip), + value.ir_value(loc=loc, ip=ip), + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + ) + else: + if value.width != 64: + raise ValueError( + "shuffle_sync only supports 64 bits values when the bit width is larger than 32" + ) + value = llvm.bitcast( + T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip + ) + # extract low 32 bits + low_32_bits = llvm.trunc( + T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip + ) + # extract high 32 bits + high_32_bits = llvm.lshr( + value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + high_32_bits = llvm.trunc( + T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip + ) + + low_32_bits_shfl = nvvm.shfl_sync( + T.i32(), Int32(mask).ir_value(loc=loc, ip=ip), - value.ir_value(loc=loc, ip=ip), + low_32_bits, + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + high_32_bits_shfl = nvvm.shfl_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + high_32_bits, Int32(offset).ir_value(loc=loc, ip=ip), Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), kind, loc=loc, ip=ip, ) - ) + # combine low and high 32 bits + low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip) + high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip) + shlf_res = llvm.shl( + high_64_bit, + Int64(32).ir_value(loc=loc, ip=ip), + llvm.IntegerOverflowFlags.none, + loc=loc, + ip=ip, + ) + shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip) + shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip) + return orig_type(shlf_res) shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx) shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up) diff --git a/python/CuTeDSL/cutlass/cute/arch/smem.py b/python/CuTeDSL/cutlass/cute/arch/smem.py index 4e5dee7b..37f87ea6 100644 --- a/python/CuTeDSL/cutlass/cute/arch/smem.py +++ b/python/CuTeDSL/cutlass/cute/arch/smem.py @@ -94,3 +94,15 @@ def get_dyn_smem( alignment, ) return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip) + + +@dsl_user_op +def get_dyn_smem_size(*, loc=None, ip=None) -> int: + """ + Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time. + This can be used for bounds checking during shared memory allocation. + + :return: The size of dynamic shared memory in bytes + :rtype: int + """ + return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index 9e240634..e3f6b1e7 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -31,6 +31,7 @@ from typing import ( Optional, ) from enum import Enum, auto +from typing_extensions import deprecated from cutlass.cutlass_dsl import ( const, @@ -517,10 +518,14 @@ class ScaledBasis: sb3 = ScaledBasis(4, [0, 1]) # 4 * E([0, 1]) # Scaled basis elements are commonly used in layout strides - layout = make_layout((4, 8), stride=(ScaledBasis(1, 0), ScaledBasis(1, 1))) + layout = make_layout((4, 8), stride=(ScaledBasis(2, 0), ScaledBasis(1, 1))) - # This creates a layout with strides (1@0, 1@1) representing + # This creates a layout with strides (2@0, 1@1) representing # a coordinate system where each dimension has its own basis + + # Example: Mapping coordinates to indices using the layout + coord = (2, 3) + idx = crd2idx(coord, layout) # Maps (2, 3) to (4, 3) """ def __init__(self, value, mode) -> None: @@ -712,8 +717,9 @@ class Swizzle(ir.Value): e.g. Given 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + the result is - 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY + 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ `xor` YY """ @@ -897,7 +903,7 @@ class _Layout(Layout): @ir.register_value_caster(_cute_ir.ComposedLayoutType.get_static_typeid(), replace=True) class ComposedLayout(ir.Value): - """ComposedLayout represents the functional composition of layouts in CuTe. + r"""ComposedLayout represents the functional composition of layouts in CuTe. A ComposedLayout is formed by the composition of three components: inner o offset o outer, where: @@ -907,7 +913,10 @@ class ComposedLayout(ir.Value): - outer: The outer layout that is applied first ComposedLayout implements the functional composition operation where: - R(c) := (inner o offset o outer)(c) := inner(offset + outer(c)) + + .. math:: + + R(c) := (inner \\circ offset \\circ outer)(c) := inner(offset + outer(c)) This composition allows for complex transformations of coordinates and indices, enabling operations like tiling, partitioning, and reshaping of data. @@ -1670,7 +1679,10 @@ def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None): :type ip: insertion pointer, optional :raises NotImplementedError: If the tensor type doesn't support trivial dereferencing - Example output: + **Example output:** + + .. code-block:: text + tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], @@ -1973,7 +1985,8 @@ def find_if( :return: Index if found at top level, tuple of indices showing nested position, or None if not found :rtype: Union[int, Tuple[int, ...], None] - Examples: + **Examples:** + .. code-block:: python # Find the first position of x in t @@ -2186,6 +2199,23 @@ def is_congruent( ) -> bool: """ Returns whether a is congruent to b. + + Congruence is an equivalence relation between hierarchical structures. + + Two objects are congruent if: + * They have the same rank, AND + * They are both non-tuple values, OR + * They are both tuples AND all corresponding elements are congruent. + + Congruence requires type matching at each level -- scalar values match with + scalar values, and tuples match with tuples of the same rank. + + :param a: First object to compare + :type a: Union[XTuple, Layout, ComposedLayout, Tensor] + :param b: Second object to compare + :type b: Union[XTuple, Layout, ComposedLayout, Tensor] + :return: True if a and b are congruent, False otherwise + :rtype: bool """ if isinstance(a, (Layout, ComposedLayout, Tensor)): a = a.shape @@ -2204,6 +2234,22 @@ def is_weakly_congruent( ) -> bool: """ Returns whether a is weakly congruent to b. + + Weak congruence is a partial order on hierarchical structures. + + Object X is weakly congruent to object Y if: + * X is a non-tuple value, OR + * X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent. + + Weak congruence allows scalar values to match with tuples, making it useful + for determining whether an object has a hierarchical structure "up to" another. + + :param a: First object to compare + :type a: Union[XTuple, Layout, ComposedLayout, Tensor] + :param b: Second object to compare + :type b: Union[XTuple, Layout, ComposedLayout, Tensor] + :return: True if a and b are weakly congruent, False otherwise + :rtype: bool """ if isinstance(a, (Layout, ComposedLayout, Tensor)): a = a.shape @@ -2261,8 +2307,11 @@ def get(input, mode: List[int], *, loc=None, ip=None): **Examples:** - For a layout like ((4,8),2):((16,1),8), get with mode=[0,1] would extract - the element 8 from the shape component. + .. code-block:: python + + layout = make_layout(((4, 8), (16, 1), 8), stride=((1, 4), (32, 0), 512)) + sub_layout = get(layout, mode=[0, 1]) # 8:4 + sub_layout = get(layout, mode=[1]) # (16, 1):(32, 0) """ # Empty mode returns input and terminates the recursive call if not mode: @@ -5065,6 +5114,11 @@ def make_layout_tv( * 2 elements per thread """ + if not isinstance(thr_layout, Layout): + raise TypeError(f"expected a Layout for thr_layout, but got {type(thr_layout)}") + if not isinstance(val_layout, Layout): + raise TypeError(f"expected a Layout for val_layout, but got {type(val_layout)}") + # Take the raked_products to compute the Layout_MN # (M,N) -> (thr_idx, val_idx) layout_mn = raked_product(thr_layout, val_layout, loc=loc, ip=ip) @@ -5081,8 +5135,52 @@ def make_layout_tv( return (tiler_mn, layout_tv) +def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + if type(tiler_mn) is tuple: + tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) + + assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance( + tiler_mn.type + ), f"tiler_mn must be a Tile, but got {type(tiler_mn)}" + assert is_static(layout_tv.type) and is_static( + tiler_mn.type + ), "layout tv and tiler mn must be static" + tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( + atom.type, layout_tv.type, tiler_mn.type + ) + + val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) + # Instead of modifying atom which might have been provided by the user, create a brand new + # trait instance and replace the Atom ir.Value with the tiled one + trait = new_from_mlir_values(atom._trait, [val]) + return TiledCopy(atom.op, trait) + + +@deprecated("Use make_tiled_copy_tv instead") +def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + """Create a tiled type given a TV partitioner and tiler. + + :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. + :type atom: CopyAtom + :param layout_tv: Thread-value layout + :type layout_tv: Layout + :param tiler_mn: Tile size + :type tiler_mn: Tiler + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + @dsl_user_op -def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> TiledCopy: +def make_tiled_copy_tv( + atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None +) -> TiledCopy: """Create a tiled copy given separate thread and value layouts. A TV partitioner is inferred based on the input layouts. The input thread layout @@ -5105,30 +5203,17 @@ def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> Ti tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip) tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip) - if not is_static(layout_tv.type) or not is_static(tiler_mn.type): - raise ValueError( - f"expects layout tv and tiler mn, but got {layout_tv.type} and {tiler_mn.type}" - ) - tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( - atom.type, layout_tv.type, tiler_mn.type - ) - val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) - # Instead of modifying atom which might have been provided by the user, create a brand new - # trait instance and replace the Atom ir.Value with the tiled one - trait = new_from_mlir_values(atom._trait, [val]) - return TiledCopy(atom.op, trait) + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) @dsl_user_op -def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): - """Create a tiled type given a TV partitioner and tiler. +def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma. - :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. + :param atom: Copy atom :type atom: CopyAtom - :param layout_tv: Thread-value layout - :type layout_tv: Layout - :param tiler_mn: Tile size - :type tiler_mn: Tiler + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma :param loc: Source location for MLIR, defaults to None :type loc: Optional[Location], optional :param ip: Insertion point, defaults to None @@ -5138,21 +5223,65 @@ def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): :rtype: TiledCopy """ - # tiler_mn = pack_tuple(tiler_mn, make_tile) - if type(tiler_mn) is tuple: - tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) - - assert is_static(layout_tv.type) and is_static( - tiler_mn.type - ), "layout tv and tiler mn must be static" - tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( - atom.type, layout_tv.type, tiler_mn.type + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_A_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_B_tiled, + (tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_C_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), + loc=loc, + ip=ip, ) - val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) - # Instead of modifying atom which might have been provided by the user, create a brand new - # trait instance and replace the Atom ir.Value with the tiled one - trait = new_from_mlir_values(atom._trait, [val]) - return TiledCopy(atom.op, trait) @dsl_user_op @@ -5172,7 +5301,7 @@ def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): :rtype: TiledCopy """ - return make_tiled_copy( + return _make_tiled_copy( atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip ) @@ -5194,7 +5323,7 @@ def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): :rtype: TiledCopy """ - return make_tiled_copy( + return _make_tiled_copy( atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip ) @@ -5273,7 +5402,7 @@ def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): tiler_mn = _pack_tile(tiler, loc=loc, ip=ip) - return make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) #################################################################################################### @@ -5297,7 +5426,7 @@ def gemm( ) -> None: """The GEMM algorithm. - Computes ``D <- AB + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. + Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field. All tensors must be partitioned according to the provided MMA Atom. @@ -6416,7 +6545,8 @@ class struct: """ Decorator to abstract C structure in Python DSL. - Usage: + **Usage:** + .. code-block:: # Supports base_dsl scalar int/float elements, array and nested struct: @@ -6424,12 +6554,15 @@ class struct: class complex: real : cutlass.Float32 imag : cutlass.Float32 + + @cute.struct class StorageA: mbarA : cute.struct.MemRange[cutlass.Int64, stage] compA : complex intA : cutlass.Int16 + # Supports aligment for its elements: @cute.struct class StorageB: @@ -6442,6 +6575,7 @@ class struct: x: cute.struct.Align[cutlass.Int32, 16] compA: cute.struct.Align[complex, 16] + # Statically get size and alignment: size = StorageB.__sizeof__() align = StorageB.__alignof__() diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py index 89f7061e..f8374407 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -94,12 +94,6 @@ def make_tiled_tma_atom( ip=ip, ) - # Wrap smem_layout in a composed layout to make it a TMA-friendly layout - if isinstance(smem_layout, Layout): - smem_layout = core.make_composed_layout( - core.make_swizzle(0, 4, 3), 0, smem_layout - ) - if isinstance(op, CopyBulkTensorTileG2SOp): if num_multicast != 1: raise ValueError( diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py index ccd06d01..9b4aa0db 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py @@ -37,7 +37,7 @@ from .cpasync.copy import ( def make_tiled_tma_atom_A( op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], gmem_tensor: Tensor, - smem_layout: Layout, + smem_layout: Union[Layout, core.ComposedLayout], mma_tiler_mnk: Shape, tiled_mma: core.TiledMma, cluster_shape_vmnk: Shape, @@ -76,7 +76,7 @@ def make_tiled_tma_atom_A( :param gmem_tensor: The GMEM tensor to be loaded by this copy atom :type gmem_tensor: Tensor :param smem_layout: Shared memory layout to load the tensor into (PDSL) - :type smem_layout: Layout + :type smem_layout: Union[Layout, core.ComposedLayout] :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions :type mma_tiler_mnk: Shape :param tiled_mma: The TiledMMA that will consume the load as operands @@ -142,7 +142,7 @@ def make_tiled_tma_atom_A( def make_tiled_tma_atom_B( op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], gmem_tensor: Tensor, - smem_layout: Layout, + smem_layout: Union[Layout, core.ComposedLayout], mma_tiler_mnk: Shape, tiled_mma: core.TiledMma, cluster_shape_vmnk: Shape, @@ -181,7 +181,7 @@ def make_tiled_tma_atom_B( :param gmem_tensor: The GMEM tensor to be loaded by this copy atom :type gmem_tensor: Tensor :param smem_layout: Shared memory layout to load the tensor into (PDSL) - :type smem_layout: Layout + :type smem_layout: Union[Layout, core.ComposedLayout] :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions :type mma_tiler_mnk: Shape :param tiled_mma: The TiledMMA that will consume the load as operands diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py index 4afeb527..2831bec6 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py @@ -42,6 +42,9 @@ __all__ = [ "MmaF16BF16Op", "MmaI8Op", "MmaFP8Op", + "MmaMXF8Op", + "MmaMXF4Op", + "MmaMXF4NVF4Op", "SmemLayoutAtomKind", # # helpers.py @@ -54,4 +57,6 @@ __all__ = [ "get_tmem_copy_properties", "find_tmem_tensor_col_offset", "make_tmem_copy", + "make_s2t_copy", + "get_s2t_smem_desc_tensor", ] diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py index 1c4439a0..df954b09 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -23,6 +23,8 @@ from ..common import OpError from ...core import CopyOp, Trait from ...typing import Numeric +from .mma import CtaGroup + class Repetition(enum.Enum): """ @@ -469,3 +471,193 @@ class St32x32bOp(_StBase): class St32x32bTrait(Trait): pass + + +@dataclass(frozen=True) +class _S2TCopyBase(CopyOp): + cta_group: CtaGroup + + admissible_archs = [ + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + + def __str__(self) -> str: + res = ( + f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" + + f"\n CTA group = {self.cta_group}" + ) + + return res + + +@dataclass(frozen=True) +class Cp128x256bOp(_S2TCopyBase): + """ + 128x256b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.128x256b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp128x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 128, + 256, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp128x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp128x128bOp(_S2TCopyBase): + """ + 128x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.128x128b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp128x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 128, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp128x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp4x256bOp(_S2TCopyBase): + """ + 4x256b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.4x256b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp4x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 4, + 256, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp4x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp4x32x128bOp(_S2TCopyBase): + """ + 32x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp4x32x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 32, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.x4, + ) + return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp4x32x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp2x64x128b0213Op(_S2TCopyBase): + """ + 64x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp2x64x128b0213Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 64, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.lw_0213, + ) + return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp2x64x128b0213Trait(Trait): + pass + + +@dataclass(frozen=True) +class Cp2x64x128b0123Op(_S2TCopyBase): + """ + 64x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp2x64x128b0123Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 64, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.lw_0123, + ) + return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp2x64x128b0123Trait(Trait): + pass diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py index cac64131..0ad27e62 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py @@ -299,3 +299,30 @@ def make_tmem_copy( ) new_trait = type(atom._trait)(tiled_copy_val) return core.TiledCopy(atom.op, new_trait) + + +@dsl_user_op +def make_s2t_copy( + atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None +) -> core.TiledCopy: + """ + Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. + """ + tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy( + atom._trait.value, tmem_tensor.value, loc=loc, ip=ip + ) + new_trait = type(atom._trait)(tiled_copy_val) + return core.TiledCopy(atom.op, new_trait) + + +@dsl_user_op +def get_s2t_smem_desc_tensor( + atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None +) -> Tensor: + """ + Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor. + """ + smem_desc_tensor = _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view( + atom._trait.value, smem_tensor.value, loc=loc, ip=ip + ) + return smem_desc_tensor diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py index 7f8f0de9..3a938523 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -20,9 +20,12 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor +from ... import core +from ...core import Trait, _pack_shape, rank, depth, _Tensor from ...typing import ( Shape, + Float4E2M1FN, + Float8E8M0FNU, Float8E5M2, Float8E4M3FN, Float16, @@ -35,6 +38,7 @@ from ...typing import ( Int32, Numeric, AddressSpace, + Pointer, ) @@ -104,7 +108,6 @@ class CtaGroup(enum.Enum): def __repr__(self) -> str: return f"<{self.__class__.__name__}.{self.name}>" - class Field(enum.Enum): """ An enumeration for the fields of the MMA Atom that can be modified at runtime. @@ -113,6 +116,8 @@ class Field(enum.Enum): NEGATE_A = "neg_a" NEGATE_B = "neg_b" ACCUMULATE = "accum_c" + SFA = "sf_a" + SFB = "sf_b" def __str__(self) -> str: return f"{self.__class__.__name__}.{self.name}" @@ -124,9 +129,9 @@ class Field(enum.Enum): return self.value -# Base class for all tcgen05 MMA Ops used to factor out some internal code +# Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code @dataclass(frozen=True) -class MmaOp(MmaOp): +class MmaOp(core.MmaOp): a_dtype: Type[Numeric] b_dtype: Type[Numeric] acc_dtype: Type[Numeric] @@ -256,6 +261,155 @@ class MmaTrait(Trait): ) +# Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code +@dataclass(frozen=True) +class BlockScaledMmaOp(core.MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Float32 + sf_dtype: Type[Numeric] + sf_vec_size: int + shape_mnk: Shape + cta_group: CtaGroup + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + + admissible_archs = [ + "sm_100a", + ] + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + # Verify the instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got {self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + if self.cta_group == CtaGroup.ONE: + if m != 128: + raise OpError(self, f"expects the M-mode to be 128, but got {m}") + + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", + ) + else: + if m not in [128, 256]: + raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") + if (n < 16) or (n > 256) or (n % 16 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}", + ) + if self.sf_vec_size not in [16, 32]: + raise OpError( + self, + f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Scale factor data type = {self.sf_dtype}" + + f"\n Scale factor vector size = {self.sf_vec_size}" + + f"\n CTA group = {self.cta_group}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class BlockScaledMmaTraits(Trait): + admissible_fields = [ + Field.ACCUMULATE, + Field.NEGATE_A, + Field.NEGATE_B, + Field.SFA, + Field.SFB, + ] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]: + value = Boolean(value).ir_value(loc=loc, ip=ip) + elif field in [Field.SFA, Field.SFB]: + if not isinstance(value, Pointer): + raise ValueError( + f"expects value to be a pointer for {field}, but got {type(value).__name__}" + ) + value = value.value + + field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, value, loc=loc, ip=ip + ) + + # # TF32 MMA # @@ -602,6 +756,262 @@ class MmaFP8Trait(MmaTrait): pass +# +# MXF8F6F4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF8Op(BlockScaledMmaOp): + """ + MXF8 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + Float32, + Float8E8M0FNU, + 32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Instruction shape verification + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF8Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF8Trait(BlockScaledMmaTraits): + pass + + +# +# MXF4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF4Op(BlockScaledMmaOp): + """ + MXF4 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation" + + def __init__( + self, + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + ) -> None: + super().__init__( + Float4E2M1FN, + Float4E2M1FN, + Float32, + Float8E8M0FNU, + 32, + instruction_shape, + cta_group, + a_src, + OperandMajorMode.K, + OperandMajorMode.K, + ) + self._verify() + + def _verify(self) -> None: + # Instruction shape verification + instruction_k = 64 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF4Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF4Trait(BlockScaledMmaTraits): + pass + + +# +# MXF4NVF4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF4NVF4Op(BlockScaledMmaOp): + """ + MXF4NVF4 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation" + + def __init__( + self, + sf_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + ) -> None: + super().__init__( + Float4E2M1FN, + Float4E2M1FN, + Float32, + sf_dtype, + 16, + instruction_shape, + cta_group, + a_src, + OperandMajorMode.K, + OperandMajorMode.K, + ) + self._verify() + + def _verify(self) -> None: + # Scale Factor data type verification + if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]: + raise OpError( + self, + "expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU", + ) + # Instruction shape verification + instruction_k = 64 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF4NVF4Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF4NVF4Trait(BlockScaledMmaTraits): + pass + #################################################################################################### # # SMEM layout atoms diff --git a/python/CuTeDSL/cutlass/cute/testing.py b/python/CuTeDSL/cutlass/cute/testing.py index 4a1bb016..88e0da04 100644 --- a/python/CuTeDSL/cutlass/cute/testing.py +++ b/python/CuTeDSL/cutlass/cute/testing.py @@ -28,11 +28,12 @@ import cutlass.base_dsl.jit_executor import cutlass.cute as cute from cutlass._mlir.dialects import builtin, cf, nvvm, vector from cutlass.cute import core, nvgpu -from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t +from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op -def assert_(cond, msg=None): - cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "") +@dsl_user_op +def assert_(cond, msg=None, *, loc=None, ip=None): + cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip) def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout): @@ -214,7 +215,14 @@ def convert(src: core.Tensor, dst: core.Tensor): dst.shape ), "Shape of src and dst tensors should be the same rank." # find leading mode - leading_mode = np.argmin([np.min(s) for s in src.stride]) + leading_mode = [ + idx + for idx, (shape, stride) in enumerate(zip(src.shape, src.stride)) + if shape > 1 and stride == 1 + ] + if len(leading_mode) != 1: + raise ValueError(f"Leading mode should be unique, but got {leading_mode}") + leading_mode = leading_mode[0] elem_per_copy = 2 @@ -345,7 +353,7 @@ def benchmark( callable: Callable, *, warmup_iterations: int = 10, - profiling_iterations: int = 100, + iterations: int = 100, stream: Optional[cuda_driver.CUstream] = None, kernel_arguments: Optional[JitArguments] = None, workspace_generator: Optional[Callable[[], JitArguments]] = None, @@ -365,7 +373,7 @@ def benchmark( pass time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream) - warmup_iterations=10, profiling_iterations=100 + warmup_iterations=10, iterations=100 stream=stream) To prevent skewing results by repeately accessing the L2 cache, use the workspace_count and workspace_generator @@ -388,7 +396,7 @@ def benchmark( workspace_generator=workspace_generator, workspace_count=10, warmup_iterations=10000, - profiling_iterations=1000) + iterations=1000) To benchmark you may always configure the function being profiled (callable), the warmup iterations, and the number of profiling iterations. @@ -402,8 +410,8 @@ def benchmark( :type callable: Callable :param warmup_iterations: Number of warmup iterations, defaults to 10 :type warmup_iterations: int, optional - :param profiling_iterations: Number of benchmark iterations, defaults to 100 - :type profiling_iterations: int, optional + :param iterations: Number of benchmark iterations, defaults to 100 + :type iterations: int, optional :param stream: Stream kernel is launched in, defaults to CUDA stream default :type stream: CUstream, None :param kernel_arguments: Kernel arguments to launch callable with, defaults to None @@ -502,7 +510,7 @@ def benchmark( stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal ) _cuda_success(err, "Error on stream capture") - _loop_and_call_kernel(profiling_iterations, workspace_index) + _loop_and_call_kernel(iterations, workspace_index) err, gprofile = cuda_runtime.cudaStreamEndCapture(stream) _cuda_success(err, "Error on stream capture") @@ -557,7 +565,7 @@ def benchmark( # Record start event err = cuda_driver.cuEventRecord(start_event, stream) _cuda_success(err, "Error on recording event") - _loop_and_call_kernel(profiling_iterations, workspace_index) + _loop_and_call_kernel(iterations, workspace_index) # Record end event err = cuda_driver.cuEventRecord(end_event, stream) _cuda_success(err, "Error on recording event") @@ -573,6 +581,30 @@ def benchmark( err = cuda_driver.cuEventDestroy(end_event) _cuda_success(err, "Error on destroying event") - return elapsed_time / profiling_iterations * 1e3 + return elapsed_time / iterations * 1e3 +def get_workspace_count( + one_workspace_bytes: int, warmup_iterations: int, iterations: int +) -> int: + """Calculate the number of workspaces needed to fill L2 cache. + + :param one_workspace_bytes: Size of one workspace in bytes + :type one_workspace_bytes: int + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations + :type iterations: int + :return: Number of workspaces needed + :rtype: int + """ + num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes() + return max( + 1, + min( + warmup_iterations + iterations, # Don't create more workspaces than needed + (num_l2_cache_bytes + one_workspace_bytes - 1) + // one_workspace_bytes, # Ceiling division + ), + ) + diff --git a/python/CuTeDSL/cutlass/pipeline/sm90.py b/python/CuTeDSL/cutlass/pipeline/sm90.py index 53e4dc8e..71b99519 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm90.py +++ b/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -91,33 +91,6 @@ class PipelineAsync: - D: Data ready (producer has written data to buffer) - R: Consumer reading (consumer is consuming data from buffer) - **Example:** - - .. code-block:: python - - # Create pipeline with 5 stages - pipeline = PipelineAsync.create( - num_stages=5, # number of pipeline stages - producer_group=producer_warp, - consumer_group=consumer_warp - barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory - ) - - # Producer side - producer = pipeline.make_pipeline_producer(producer_warp) - for i in range(num_iterations): - producer.acquire() # Wait for buffer to be empty - # Write data to pipeline buffer - producer.commit() # Signal buffer is full - producer.advance() # Move index to next stage - - # Consumer side - consumer = pipeline.make_pipeline_consumer(consumer_warp) - for i in range(num_iterations): - consumer.wait() # Wait for buffer to be full - # Read data from pipeline buffer - consumer.release() # Signal buffer is empty - consumer.advance() # Move index to next stage """ sync_object_full: SyncObject @@ -259,16 +232,6 @@ class PipelineAsync: state.advance() self.producer_acquire(state) - # Util methods to manage produer and consumer - def make_pipeline_producer(self, group: CooperativeGroup): - state = make_pipeline_state(PipelineUserType.Producer, self.num_stages) - return PipelineProducer(self, state, group) - - def make_pipeline_consumer(self, group: CooperativeGroup): - state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages) - return PipelineConsumer(self, state, group) - - @dataclass(frozen=True) class PipelineTmaAsync(PipelineAsync): """ @@ -593,211 +556,3 @@ class PipelineTmaStore(PipelineAsync): self.sync_object_full.tail() -################################################################# -# Utilities to help user of pipeline to simplify the workflow -################################################################# - - -class PipelineProducer: - """A class representing a producer in an asynchronous pipeline. - - The Producer class manages the producer side of an asynchronous pipeline, handling - synchronization and state management for producing data. It provides methods for - acquiring, committing, and advancing through pipeline stages. - - :ivar _pipeline: The asynchronous pipeline this producer belongs to - :type _pipeline: PipelineAsync - :ivar _state: The current state of the producer in the pipeline - :type _state: PipelineState - :ivar _group: The cooperative group this producer operates in - :type _group: CooperativeGroup - - **Examples:** - - .. code-block:: python - - pipeline = PipelineAsync.create(...) - producer = pipeline.create_producer(producer_group, stages) - for i in range(iterations): - producer.acquire() # Wait for buffer to be empty - # Produce data - producer.commit() # Signal data is ready - producer.advance() # Move to next stage - """ - - _pipeline: PipelineAsync - _state: PipelineState - _group: CooperativeGroup - - def __init__(self, pipeline, state, group: CooperativeGroup): - """Initialize a new Producer instance. - - :param pipeline: The pipeline this producer belongs to - :type pipeline: PipelineAsync - :param state: Initial pipeline state - :type state: PipelineState - :param group: The cooperative group for synchronization - :type group: CooperativeGroup - """ - self._pipeline = pipeline - self._state = state - self._group = group - - @property - def index(self): - """Get the index of the current pipeline stage.""" - return self._state.index - - def get_barrier(self): - """Get the barrier pointer for the current pipeline stage. - - :return: Pointer to the barrier for the current stage - :rtype: cute.Pointer - """ - return self._pipeline.producer_get_barrier(self._state) - - def acquire(self): - """Wait for the current buffer to be empty before producing data. - This is a blocking operation. - """ - self._pipeline.producer_acquire(self._state) - - def try_acquire(self): - """Try to acquire the current buffer without blocking. - - :return: True if acquisition was successful, False otherwise - :rtype: bool - """ - self._pipeline.producer_try_acquire(self._state) - - def commit(self): - """Signal that data production is complete for the current stage. - This allows consumers to start processing the data. - """ - self._pipeline.producer_commit(self._state) - - def tail(self): - """Ensure all used buffers are properly synchronized before producer exit. - This should be called before the producer finishes to avoid dangling signals. - """ - self._pipeline.producer_tail(self._state) - - def advance(self): - """Move to the next pipeline stage.""" - self._state.advance() - - def __extract_mlir_values__(self): - """Extract MLIR values from the current state. - - :return: List of MLIR values representing the current state - :rtype: list - """ - # TODO: need to handle pipeline as well - return self._state.__extract_mlir_values__() - - def __new_from_mlir_values__(self, values): - """Create a new Producer instance from MLIR values. - - :param values: MLIR values to initialize the state - :type values: Any - :return: New Producer instance with state initialized from values - :rtype: Producer - """ - return PipelineProducer( - self._pipeline, self._state.__new_from_mlir_values__(values), self._group - ) - - -class PipelineConsumer: - """A class representing a consumer in an asynchronous pipeline. - - The Consumer class manages the consumer side of an asynchronous pipeline, handling - synchronization and state management for consuming data. It provides methods for - waiting, releasing, and advancing through pipeline stages. - - :ivar _pipeline: The asynchronous pipeline this consumer belongs to - :type _pipeline: PipelineAsync - :ivar _state: The current state of the consumer in the pipeline - :type _state: PipelineState - :ivar _group: The cooperative group this consumer operates in - :type _group: CooperativeGroup - - **Examples:** - .. code-block:: python - - pipeline = PipelineAsync.create(...) - consumer = pipeline.create_consumer(consumer_group, stages) - for i in range(iterations): - consumer.wait() # Wait for data to be ready - # Consume data - consumer.release() # Signal buffer is empty - consumer.advance() # Move to next stage - """ - - _pipeline: PipelineAsync - _state: PipelineState - _group: CooperativeGroup - - def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup): - """Initialize a new Consumer instance. - - :param pipeline: The pipeline this consumer belongs to - :type pipeline: PipelineAsync - :param state: Initial pipeline state - :type state: PipelineState - :param group: The cooperative group for synchronization - :type group: CooperativeGroup - """ - self._pipeline = pipeline - self._group = group - self._state = state - - @property - def index(self): - """Get the index of the current pipeline stage.""" - return self._state.index - - def wait(self): - """Wait for data to be ready in the current buffer. - This is a blocking operation. - """ - self._pipeline.consumer_wait(self._state) - - def try_wait(self): - """Try to check if data is ready without blocking. - - :return: True if data is ready, False otherwise - :rtype: bool - """ - self._pipeline.consumer_try_wait(self._state) - - def release(self): - """Signal that data consumption is complete for the current stage. - This allows producers to start producing new data. - """ - self._pipeline.consumer_release(self._state) - - def advance(self): - """Move to the next pipeline stage.""" - self._state.advance() - - def __extract_mlir_values__(self): - """Extract MLIR values from the current state. - - :return: List of MLIR values representing the current state - :rtype: list - """ - return self._state.__extract_mlir_values__() - - def __new_from_mlir_values__(self, values): - """Create a new Consumer instance from MLIR values. - - :param values: MLIR values to initialize the state - :type values: Any - :return: New Consumer instance with state initialized from values - :rtype: Consumer - """ - # TODO: need to call pipeline.__new_from_mlir_values__ recursively - return PipelineConsumer( - self._pipeline, self._state.__new_from_mlir_values__(values), self._group - ) diff --git a/python/CuTeDSL/cutlass/torch.py b/python/CuTeDSL/cutlass/torch.py index 32bf0738..066c2816 100644 --- a/python/CuTeDSL/cutlass/torch.py +++ b/python/CuTeDSL/cutlass/torch.py @@ -29,7 +29,7 @@ from cutlass.cute.typing import ( from cutlass.cute.runtime import from_dlpack import cutlass.cute as cute import torch -from cuda import cuda +import cuda.bindings.driver as cuda def dtype(ty: Type[Numeric]): diff --git a/python/CuTeDSL/cutlass/utils/__init__.py b/python/CuTeDSL/cutlass/utils/__init__.py index 30bd2d4c..39add25e 100644 --- a/python/CuTeDSL/cutlass/utils/__init__.py +++ b/python/CuTeDSL/cutlass/utils/__init__.py @@ -28,12 +28,22 @@ from .blackwell_helpers import ( make_smem_layout_b, make_smem_layout_epi, make_trivial_tiled_mma, + make_blockscaled_trivial_tiled_mma, ) from .hopper_helpers import ( sm90_get_smem_store_op, ) +from .blockscaled_layout import ( + BlockScaledBasicChunk, + tile_atom_to_shape_SF, + make_smem_layout_sfa, + make_smem_layout_sfb, + make_tmem_layout_sfa, + make_tmem_layout_sfb, +) + from .grouped_gemm_tile_scheduler_helper import ( GroupSearchResult, GroupedGemmGroupSearchState, @@ -50,7 +60,12 @@ from .smem_allocator import SmemAllocator from .layout import LayoutEnum +from .smem_capacity import ( + get_smem_capacity_in_bytes, +) + __all__ = [ + "get_smem_capacity_in_bytes", "SmemAllocator", "LayoutEnum", "WorkTileInfo", diff --git a/python/CuTeDSL/cutlass/utils/ampere_helpers.py b/python/CuTeDSL/cutlass/utils/ampere_helpers.py index 1ba97e1c..1341756f 100644 --- a/python/CuTeDSL/cutlass/utils/ampere_helpers.py +++ b/python/CuTeDSL/cutlass/utils/ampere_helpers.py @@ -10,14 +10,22 @@ # is strictly prohibited. from enum import Enum +from typing_extensions import deprecated +import warnings +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") class SmemCapacity(Enum): SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024 SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) # Dictionary to map compute capability to SMEM capacity SMEM_CAPACITY = { "sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value, diff --git a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py index 167f3efb..6fb6bf4d 100644 --- a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py +++ b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -12,6 +12,8 @@ from enum import Enum from math import log2, ceil from typing import List, Type, Union, Tuple +from typing_extensions import deprecated +import warnings from cutlass.cutlass_dsl import ( Float16, @@ -22,6 +24,7 @@ from cutlass.cutlass_dsl import ( Int8, Float8E4M3FN, Float8E5M2, + Float4E2M1FN, Numeric, NumericMeta, dsl_user_op, @@ -34,6 +37,9 @@ from cutlass.cute.nvgpu.tcgen05 import ( MmaTF32Op, MmaI8Op, MmaFP8Op, + MmaMXF8Op, + MmaMXF4Op, + MmaMXF4NVF4Op, OperandSource, OperandMajorMode, CtaGroup, @@ -58,6 +64,24 @@ from cutlass.cute.nvgpu.cpasync import ( from cutlass.utils.layout import LayoutEnum +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 + SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value, + "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value, +} + + @dsl_user_op def compute_epilogue_tile_shape( cta_tile_shape: cute.Shape, @@ -822,18 +846,6 @@ def make_smem_layout_epi( return epi_smem_layout_staged -class SmemCapacity(Enum): - SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 - SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 - - -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value, - "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value, -} - - @dsl_user_op def make_trivial_tiled_mma( ab_dtype: Type[Numeric], @@ -917,6 +929,76 @@ def make_trivial_tiled_mma( return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) +@dsl_user_op +def make_blockscaled_trivial_tiled_mma( + ab_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + sf_dtype: Type[Numeric], + sf_vec_size: int, + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, +) -> cute.TiledMma: + """Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. + By default, the MMA atom is created with SMEM operand source for A. + + :param ab_dtype: Data type of operands A and B. + :type ab_dtype: type[Numeric] + :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). + :type a_leading_mode: tcgen05.OperandMajorMode + :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). + :type b_leading_mode: tcgen05.OperandMajorMode + :param sf_dtype: Data type of the Scale Factor. + :type sf_dtype: type[Numeric] + :param sf_vec_size: The vector size of the Scale Factor. + :type sf_vec_size: int + :param cta_group: The CTA group to use. + :type cta_group: tcgen05.CtaGroup + :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mn: Tuple[int, int] + :param a_source: The source of operand A (SMEM by default or TMEM). + :type a_source: OperandSource + + :return: A tiled MMA atom. + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + """ + if ab_dtype in {Float8E4M3FN, Float8E5M2}: + mma_op = MmaMXF8Op( + ab_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype == Float4E2M1FN: + if sf_vec_size == 32: + mma_op = MmaMXF4Op( + (*mma_tiler_mn, 64), + cta_group, + a_source, + ) + elif sf_vec_size == 16: + mma_op = MmaMXF4NVF4Op( + sf_dtype, + (*mma_tiler_mn, 64), + cta_group, + a_source, + ) + else: + raise ValueError(f"unsupported sf_vec_size, got {sf_vec_size}") + else: + raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") + + return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + @dsl_user_op def cluster_shape_to_tma_atom_A( cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None diff --git a/python/CuTeDSL/cutlass/utils/blockscaled_layout.py b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py new file mode 100644 index 00000000..fa1e2eb7 --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from dataclasses import dataclass, field +from typing import Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass.cute as cute +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + + +@dataclass(frozen=True) +class BlockScaledBasicChunk: + """ + The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops. + + This class represents the fixed layout pattern for scale factors used in + tcgen05 BlockScaled MMA Ops. The layout is determined by the + instruction specification and cannot be modified. + See `PTX documentation `. + """ + + sf_vec_size: int + major_mode: OperandMajorMode = OperandMajorMode.K + _layout: cute.Layout = field(init=False, repr=False) + + def __post_init__(self) -> None: + if self.major_mode == OperandMajorMode.K: + # K-major layout: (AtomMN, AtomK) + atom_shape = ((32, 4), (self.sf_vec_size, 4)) + atom_stride = ((16, 4), (0, 1)) + else: + # MN-major layout: (AtomK, AtomMN) + atom_shape = ((self.sf_vec_size, 4), (32, 4)) + atom_stride = ((0, 1), (16, 4)) + + object.__setattr__( + self, "_layout", cute.make_layout(atom_shape, stride=atom_stride) + ) + + @property + def layout(self) -> cute.Layout: + """ + Get the layout for this block scaled chunk. + + :return: The layout representing the scale factor atom + :rtype: cute.Layout + """ + return self._layout + + +@dsl_user_op +def tile_atom_to_shape_SF( + Shape: cute.Shape, + sf_vec_size: int, + *, + loc=None, + ip=None, +) -> cute.Layout: + """ + A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout. + + :param Shape: The shape of the A/B tensor + :param sf_vec_size: Scale factor vector size + + :return: The layout of the SFA/SFB tensor + :rtype: cute.Layout + """ + # ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL) + sf_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3) + ) + return sf_layout + + +@dsl_user_op +def make_smem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc=None, + ip=None, +) -> cute.Layout: + """ + Make smem layout for SFA based on: + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + # (CTA_Tile_Shape_M, MMA_Tile_Shape_K) + sfa_tile_shape = ( + mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), + mma_tiler_mnk[2], + ) + + # ((Atom_M, Rest_M),(Atom_K, Rest_K)) + smem_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, + sfa_tile_shape, + (2, 1), + ) + + mma_tile_inst_k = 4 + # (CTA_Tile_Shape_M, MMA_Inst_Shape_K) + sfa_tile_shape = cute.shape_div(sfa_tile_shape, (1, mma_tile_inst_k)) + # ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K)) + smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape) + + atom_m = 128 + tiler_inst = ((atom_m, sf_vec_size),) + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) + smem_layout = cute.logical_divide(smem_layout, tiler_inst) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfa_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfa_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc=None, + ip=None, +) -> cute.Layout: + """ + Make smem layout for SFB based on: + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + # (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K) + sfb_tile_shape = ( + cute.round_up(mma_tiler_mnk[1], 128), + mma_tiler_mnk[2], + ) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K)) + smem_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, + sfb_tile_shape, + (2, 1), + ) + + mma_tile_inst_k = 4 + # (CTA_Tile_Shape_N, MMA_Inst_Shape_K) + sfb_tile_shape = cute.shape_div(sfb_tile_shape, (1, mma_tile_inst_k)) + # ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K) + smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape) + + atom_n = 128 + tiler_inst = ((atom_n, sf_vec_size),) + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) + smem_layout = cute.logical_divide(smem_layout, tiler_inst) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfb_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfb_smem_layout_staged + + +@dsl_user_op +def make_tmem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + smem_layout: cute.Layout, + *, + loc=None, + ip=None, +) -> cute.Layout: + """Make tmem layout for SFA based on: + 1. SFA smem layout per stage + 2. Cta tile shape m + 3. tiled MMA atom thr size + 4. Scale factor vector size + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param smem_layout: The smem layout of SFA per stage + :type smem_layout: cute.Layout + + :return: TMEM layout for SFA + :rtype: cute.Layout + """ + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size + + sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa( + smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size + ) + return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip) + + +@dsl_user_op +def make_tmem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + smem_layout: cute.Layout, + *, + loc=None, + ip=None, +) -> cute.Layout: + """Make tmem layout for SFB based on: + 1. SFB smem layout per stage + 2. Cta tile shape m + 3. tiled MMA atom thr size + 4. Scale factor vector size + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param smem_layout: The smem layout of SFB per stage + :type smem_layout: cute.Layout + + :return: TMEM layout for SFB + :rtype: cute.Layout + """ + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size + + sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb( + smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size + ) + return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/python/CuTeDSL/cutlass/utils/hopper_helpers.py index 3b94d694..4cd2bae3 100644 --- a/python/CuTeDSL/cutlass/utils/hopper_helpers.py +++ b/python/CuTeDSL/cutlass/utils/hopper_helpers.py @@ -11,6 +11,8 @@ from typing import Type, Tuple from enum import Enum +from typing_extensions import deprecated +import warnings from cutlass.utils.layout import LayoutEnum from cutlass.cutlass_dsl import ( @@ -34,6 +36,23 @@ from cutlass.cute.nvgpu.warpgroup import ( OperandSource, ) + +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value, +} + + @dsl_user_op def sm90_get_smem_store_op( layout_d: LayoutEnum, @@ -79,15 +98,6 @@ def sm90_get_smem_store_op( return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) -class SmemCapacity(Enum): - SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 - - -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value, -} - def make_trivial_tiled_mma( a_dtype: Type[Numeric], b_dtype: Type[Numeric], diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index ded0ae43..9490fb58 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -14,7 +14,7 @@ from typing import Type, Union, overload from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta import cutlass.cute as cute -from cutlass.cute.arch import get_dyn_smem +from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size class SmemAllocator: @@ -60,6 +60,7 @@ class SmemAllocator: :return: Pointer to the start of the allocated memory block or struct instance :rtype: cute.Pointer :raises ValueError: If size is negative or alignment is less than 1 + :raises RuntimeError: If allocation would exceed available shared memory """ if isinstance(size_or_type, cute.struct): alignment = max(byte_alignment, size_or_type.__alignof__()) @@ -80,6 +81,14 @@ class SmemAllocator: byte_alignment - self._allocated_bytes % byte_alignment ) self._allocated_bytes += num_bytes + + # Check bounds against available dynamic shared memory + cute.testing.assert_( + self._allocated_bytes <= get_dyn_smem_size(), + f"Allocation failed: shared memory allocation exceeds available memory set in kernel launch. " + f"Allocated bytes: {self._allocated_bytes} bytes. " + f"Please reduce the allocation or set a larger smem size in kernel launch.", + ) return ptr def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1): diff --git a/python/CuTeDSL/cutlass/utils/smem_capacity.py b/python/CuTeDSL/cutlass/utils/smem_capacity.py new file mode 100644 index 00000000..87ddb990 --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/smem_capacity.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + + +SMEM_CAPACITY_MAP = { + "sm_120": (100 - 1) * 1024, + "sm_100": (228 - 1) * 1024, + "sm_90": (228 - 1) * 1024, + "sm_80": (164 - 1) * 1024, + "sm_86": (100 - 1) * 1024, + "sm_89": (100 - 1) * 1024, +} + + +def get_smem_capacity_in_bytes(compute_capability: str) -> int: + if compute_capability not in SMEM_CAPACITY_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return SMEM_CAPACITY_MAP[compute_capability] diff --git a/python/CuTeDSL/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass_dsl/cutlass.py index 9a61a746..e2461d50 100644 --- a/python/CuTeDSL/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass_dsl/cutlass.py @@ -159,7 +159,11 @@ class CutlassBaseDSL(BaseDSL): pipeline = super()._get_pipeline(pipeline) if pipeline == None: # cubin format is required to be cubin as we launch cuda module at python level. - return "builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3})" + return ( + "builtin.module(cute-to-nvvm{cubin-format=bin " + + self.compile_options.to_str() + + "})" + ) return pipeline @@ -294,13 +298,8 @@ class CutlassBaseDSL(BaseDSL): self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs ) - def _get_globals(self): - caller_globals = self.frame.f_globals - caller_locals = self.frame.f_locals - all_globals = globals().copy() - all_globals.update(caller_globals) - all_globals.update(caller_locals) - return all_globals + def _get_module_globals(self): + return globals() def _preprocess_launch_config_args(self, args, kwargs): """Helper to preprocess args and kwargs for LaunchConfig""" @@ -459,7 +458,10 @@ class KernelLauncher: def _check_func_args(self, funcBody, *func_args, **func_kwargs): # Get function signature - sig = inspect.signature(funcBody) + if isinstance(funcBody, DSLCallable): + sig = funcBody.get_signature() + else: + sig = inspect.signature(funcBody) # func_args and func_kwargs should match funcBody's signature, # no extra or missing arguments. @@ -485,6 +487,7 @@ class KernelLauncher: ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config) self.dsl.kernel_symbols.append(name) + self.dsl.frame = None return ret.launch_op_ret def __call__(self, *args, **kwargs): @@ -537,14 +540,18 @@ def pack_from_irvalue( mixed_values[idx] = obj elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"): mixed_values[idx] = obj.__new_from_mlir_values__(chunk) + elif isinstance(chunk, list) and chunk[0] is None: + mixed_values[idx] = class_types[idx] else: - try: - if isinstance(chunk, list) and chunk[0] is None: - mixed_values[idx] = class_types[idx] - else: + if len(chunk) == 1: + try: mixed_values[idx] = t.as_numeric(chunk[0]) - except DSLRuntimeError as e: - mixed_values[idx] = chunk[0] + except ValueError: + # Suppress the conversion error and try new_from_mlir_values below + pass + + if mixed_values[idx] is None: + mixed_values[idx] = new_from_mlir_values(obj, chunk) log().debug("------------------ ") for idx, packed in enumerate(mixed_values): diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index aec440f9..093ed1d6 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.1.0.dev0 +nvidia-cutlass-dsl==4.1.0 diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 0e6d8883..4ae2d8ed 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -278,7 +278,11 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode ): - profiler_reference_computing = "--verification-providers=device --providers=cutlass" + # For functional testing, we prefer to run reference computing on device if any + reference_device_archs = ["100a"] + run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False + profiler_flags_for_verification = "device" if run_reference_on_device else "host" + # beta values for L0 and L1 # TODO: randomize beta values for wider coverage beta_values = [0.5] @@ -408,7 +412,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \ f"({sm100_mma_filter_regex_2sm})|" \ f"({block_scaled_filter_regex_1sm})|" \ - f"({block_scaled_filter_regex_2sm})|" + f"({block_scaled_filter_regex_2sm})" # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times sm120_mma_kernel_cta_tiles = [ # h1688, s1688, i16832, i8816 @@ -545,11 +549,22 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode elif "ue8m0xf8_ue8m0xf8" in kernel_name: runtime_input_datatypes = [['e4m3','e4m3']] + if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + profiler_flags_for_verification = "host" + + # reduce L1 test runtime if reference kernel is not running on device. + if mode == "functional_L1" and profiler_flags_for_verification == "host" : + problem_waves = [0.5, 2.5] + + if dynamic_cluster: if mode == "functional_L0": runtime_cluster_shapes = [[1,1,1], [2,2,1]] else: runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]] + # reduce L1 test runtime if reference kernel is not running on device. + if profiler_flags_for_verification == "host": + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]] cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape else: runtime_cluster_shapes = [operation.tile_description.cluster_shape] @@ -643,11 +658,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode batch_count = 3 if mode == "functional_L0" else 5 gemm_op = "gemm" - profiler_reference_computing_override = profiler_reference_computing grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind) num_groups = 1 - if "bstensorop" in kernel_name: - profiler_reference_computing_override = "--mode=trace" if grouped: gemm_op = "grouped_gemm" num_groups = 3 # small to limit test time in host block-scaled reference kernels @@ -695,7 +707,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b testcase_metadata = [ - f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" + + f"cutlass_profiler --operation={gemm_op}" + + (f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") + + f" --error-on-no-match --error-if-nothing-is-profiled" + f" --kernels={kernel_name}" + f" --m={str(m)}" + f" --n={str(n)}" + diff --git a/tools/library/src/reference/blockwise_gemm_reference_operation.h b/tools/library/src/reference/blockwise_gemm_reference_operation.h index 591a5ce3..fd988f89 100644 --- a/tools/library/src/reference/blockwise_gemm_reference_operation.h +++ b/tools/library/src/reference/blockwise_gemm_reference_operation.h @@ -606,10 +606,46 @@ void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &mani float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); make_blockwise_gemm< @@ -620,10 +656,46 @@ void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &mani float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, @@ -633,11 +705,46 @@ void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &mani float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); - + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, @@ -647,10 +754,46 @@ void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &mani float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); }