From a1aaf2300a8fc3a8106a05436e1a2abad0930443 Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Thu, 3 Jul 2025 20:07:53 +0800 Subject: [PATCH] v4.1 release --- CHANGELOG.md | 42 +- README.md | 91 +- .../72a_blackwell_nvfp4_bf16_gemm.cu | 26 +- .../72b_blackwell_nvfp4_nvfp4_gemm.cu | 30 +- .../72c_blackwell_mixed_mxfp8_bf16_gemm.cu | 28 +- .../75_blackwell_grouped_gemm_block_scaled.cu | 24 +- .../77_blackwell_fmha_bwd.cu | 253 +- examples/77_blackwell_fmha/CMakeLists.txt | 31 +- examples/77_blackwell_fmha/README.md | 6 + .../collective/fmha_fusion.hpp | 118 +- .../device/fmha_device_bwd.hpp | 44 +- .../kernel/fmha_kernel_bwd_convert.hpp | 29 +- .../kernel/fmha_kernel_bwd_sum_OdO.hpp | 22 +- ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 244 +- .../sm100_fmha_mla_tma_warpspecialized.hpp | 6 + .../reference/fmha_bwd_reference.hpp | 132 +- .../79a_blackwell_geforce_nvfp4_bf16_gemm.cu | 24 +- .../79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu | 26 +- ...ell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu | 22 +- ...9d_blackwell_geforce_nvfp4_grouped_gemm.cu | 26 +- ...lackwell_geforce_mxfp8_bf16_sparse_gemm.cu | 18 +- ...ackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu | 18 +- .../84a_blackwell_nvfp4_bf16_sparse_gemm.cu | 32 +- ..._blackwell_mixed_mxfp8_bf16_sparse_gemm.cu | 32 +- .../python/CuTeDSL/ampere/call_from_jit.py | 259 ++ .../python/CuTeDSL/ampere/elementwise_add.py | 80 +- .../CuTeDSL/ampere/elementwise_apply.py | 121 +- .../CuTeDSL/ampere/flash_attention_v2.py | 54 +- examples/python/CuTeDSL/ampere/sgemm.py | 71 +- .../python/CuTeDSL/ampere/tensorop_gemm.py | 852 ++-- .../python/CuTeDSL/blackwell/dense_gemm.py | 151 +- .../blackwell/dense_gemm_persistent.py | 172 +- .../blackwell/dense_gemm_software_pipeline.py | 1852 +++++++++ examples/python/CuTeDSL/blackwell/fmha.py | 2235 +++++----- .../python/CuTeDSL/blackwell/grouped_gemm.py | 340 +- .../blackwell/mamba2_ssd/mamba2_ssd.py | 3619 +++++++++++++++++ .../mamba2_ssd/mamba2_ssd_reference.py | 397 ++ .../mamba2_ssd/mamba2_ssd_tile_scheduler.py | 200 + .../python/CuTeDSL/cute/ffi/CMakeLists.txt | 2 +- examples/python/CuTeDSL/hopper/dense_gemm.py | 52 +- .../notebooks/cute_layout_algebra.ipynb | 2 +- .../python/CuTeDSL/notebooks/data_types.ipynb | 43 +- .../python/CuTeDSL/notebooks/tensorssa.ipynb | 2 +- include/cute/algorithm/copy.hpp | 7 +- include/cute/atom/copy_atom.hpp | 125 - include/cute/atom/mma_atom.hpp | 442 +- include/cute/atom/mma_traits_sm100.hpp | 35 + include/cute/layout.hpp | 183 +- include/cute/numeric/arithmetic_tuple.hpp | 264 +- include/cute/numeric/int.hpp | 15 +- include/cute/numeric/numeric_types.hpp | 69 +- include/cute/pointer.hpp | 26 +- include/cute/pointer_flagged.hpp | 17 - include/cute/tensor.hpp | 6 + include/cute/tensor_impl.hpp | 116 +- include/cute/util/print_latex.hpp | 438 ++ include/cute/util/print_svg.hpp | 257 ++ include/cute/util/print_tensor.hpp | 188 + include/cute/util/type_traits.hpp | 23 + include/cutlass/arch/mma_sm100.h | 118 + .../detail/collective/mixed_input_utils.hpp | 106 +- .../collective/builders/sm100_builder.inl | 99 + .../cutlass/epilogue/collective/detail.hpp | 10 + .../collective/sm100_epilogue_nosmem.hpp | 20 +- .../cutlass/epilogue/fusion/operations.hpp | 19 + .../sm90_callbacks_tma_warpspecialized.hpp | 108 + ...90_visitor_compute_tma_warpspecialized.hpp | 4 +- ...sm90_visitor_store_tma_warpspecialized.hpp | 10 +- .../fusion/sm90_visitor_topk_softmax.hpp | 27 +- .../gemm/collective/builders/sm100_common.inl | 41 + .../builders/sm100_simt_builder.inl | 216 + .../gemm/collective/collective_builder.hpp | 1 + .../gemm/collective/collective_mma.hpp | 1 + .../collective/sm80_mma_array_multistage.hpp | 412 ++ ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 7 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 6 +- include/cutlass/gemm/dispatch_policy.hpp | 22 + .../cutlass/gemm/kernel/gemm_universal.hpp | 1 + ...mm_tma_warpspecialized_input_transform.hpp | 33 - .../cutlass/gemm/kernel/sm70_gemm_array.hpp | 279 ++ include/cutlass/integer_subbyte.h | 9 +- include/cutlass/numeric_size.h | 8 +- .../cute_dsl_general/dsl_control_flow.rst | 187 +- media/docs/pythonDSL/faqs.rst | 2 +- media/docs/pythonDSL/limitations.rst | 5 +- python/CuTeDSL/base_dsl/ast_helpers.py | 396 +- python/CuTeDSL/base_dsl/ast_preprocessor.py | 805 ++-- python/CuTeDSL/base_dsl/dsl.py | 48 +- python/CuTeDSL/base_dsl/env_manager.py | 6 +- python/CuTeDSL/base_dsl/jit_executor.py | 51 +- python/CuTeDSL/base_dsl/typing.py | 186 +- python/CuTeDSL/cutlass/cute/__init__.py | 4 + python/CuTeDSL/cutlass/cute/arch/__init__.py | 8 +- python/CuTeDSL/cutlass/cute/arch/elect.py | 11 +- python/CuTeDSL/cutlass/cute/arch/mbar.py | 131 +- .../cutlass/cute/arch/nvvm_wrappers.py | 36 + python/CuTeDSL/cutlass/cute/core.py | 685 +++- .../cutlass/cute/nvgpu/cpasync/__init__.py | 2 +- .../cutlass/cute/nvgpu/cpasync/copy.py | 27 +- .../cutlass/cute/nvgpu/cpasync/helpers.py | 14 +- python/CuTeDSL/cutlass/cute/nvgpu/helpers.py | 102 +- .../cutlass/cute/nvgpu/tcgen05/copy.py | 10 +- .../CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py | 5 +- .../cutlass/cute/nvgpu/warpgroup/mma.py | 4 +- python/CuTeDSL/cutlass/cute/runtime.py | 13 +- python/CuTeDSL/cutlass/cute/testing.py | 350 +- python/CuTeDSL/cutlass/cute/typing.py | 2 + python/CuTeDSL/cutlass/pipeline/__init__.py | 62 + python/CuTeDSL/cutlass/pipeline/helpers.py | 645 +++ python/CuTeDSL/cutlass/pipeline/sm100.py | 452 ++ python/CuTeDSL/cutlass/pipeline/sm90.py | 803 ++++ python/CuTeDSL/cutlass/torch.py | 123 +- python/CuTeDSL/cutlass/utils/__init__.py | 16 +- .../cutlass/utils/blackwell_helpers.py | 143 + .../CuTeDSL/cutlass/utils/hopper_helpers.py | 8 +- python/CuTeDSL/cutlass/utils/pipeline.py | 1023 ----- .../CuTeDSL/cutlass/utils/smem_allocator.py | 137 +- python/CuTeDSL/cutlass_dsl/__init__.py | 5 + python/CuTeDSL/cutlass_dsl/cutlass.py | 142 +- .../cutlass_dsl/cutlass_ast_decorators.py | 41 + python/CuTeDSL/requirements.txt | 2 +- .../backend/evt/frontend/frontend_base.py | 11 +- .../backend/evt/frontend/python_ast.py | 11 +- python/cutlass/backend/evt/ir/dag_ir.py | 20 +- python/cutlass/backend/evt/ir/tensor.py | 21 +- .../backend/evt/passes/pass_dag_2_tree.py | 116 +- python/cutlass/backend/library.py | 2 + python/cutlass/epilogue/__init__.py | 1 + python/cutlass/epilogue/evt_ops.py | 6 + python/cutlass_library/emit_kernel_listing.py | 6 +- .../python/cutlass/evt/evt_compute_sm80_90.py | 76 + test/python/cutlass/evt/evt_mixed_sm80_90.py | 45 + test/python/cutlass/evt/evt_store_sm80_90.py | 25 + test/python/cutlass/evt/utils/evt_testbed.py | 11 +- test/unit/common/filter_architecture.cpp | 1 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 47 + ...gemm_s8_s8_s32_tensorop_s32_with_fusion.cu | 47 + test/unit/conv/device_3x/testbed_conv.hpp | 3 + test/unit/cute/core/array_subbyte.cpp | 13 + test/unit/gemm/device/CMakeLists.txt | 29 + test/unit/gemm/device/gemm_testbed_3x.hpp | 226 +- .../gemm/device/gemm_testbed_3x_ptr_array.hpp | 184 +- .../sm100_gemm_f32_f32_f32_simt_align1.cu | 321 ++ ..._gemm_f32_f32_f32_simt_align1_bias_relu.cu | 337 ++ ..._gemm_f32_f32_f32_simt_align1_ptr_array.cu | 329 ++ .../include/cutlass/library/arch_mappings.h | 2 +- .../block_scaled_gemm_reference_operation.h | 34 +- .../block_scaled_gemm_operation_profiler.h | 8 +- .../profiler/gemm_operation_profiler.h | 8 +- .../grouped_gemm_operation_profiler.h | 15 +- .../block_scaled_gemm_operation_profiler.cu | 47 +- tools/profiler/src/gemm_operation_profiler.cu | 48 +- .../src/grouped_gemm_operation_profiler.cu | 43 +- .../cutlass/util/mixed_dtype_utils.hpp | 10 +- .../cutlass/util/reference/host/conv.hpp | 84 +- 155 files changed, 18407 insertions(+), 6068 deletions(-) create mode 100644 examples/python/CuTeDSL/ampere/call_from_jit.py create mode 100644 examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py create mode 100644 examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py create mode 100644 examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py create mode 100644 examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py create mode 100644 include/cute/util/print_latex.hpp create mode 100644 include/cute/util/print_svg.hpp create mode 100644 include/cute/util/print_tensor.hpp create mode 100644 include/cutlass/arch/mma_sm100.h create mode 100644 include/cutlass/gemm/collective/builders/sm100_simt_builder.inl create mode 100644 include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp create mode 100644 include/cutlass/gemm/kernel/sm70_gemm_array.hpp create mode 100644 python/CuTeDSL/cutlass/pipeline/__init__.py create mode 100644 python/CuTeDSL/cutlass/pipeline/helpers.py create mode 100644 python/CuTeDSL/cutlass/pipeline/sm100.py create mode 100644 python/CuTeDSL/cutlass/pipeline/sm90.py delete mode 100644 python/CuTeDSL/cutlass/utils/pipeline.py create mode 100644 test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_bias_relu.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_ptr_array.cu diff --git a/CHANGELOG.md b/CHANGELOG.md index 229261bc..67cc633a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,46 @@ # CUTLASS 4.x +## [4.1.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-06-30) + +### CuTe DSL +* More examples demonstrating how to use CuTe DSL to write peak-performance kernels + - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py) +* API updates + - for loop + - Python built-in ``range`` now always generates IR and executes at runtime + - ``cutlass.range`` is advanced ``range`` with IR level unrolling and pipelining control + - Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range`` + - **Experimental** Added ``pipelining`` control for compiler generated software pipeline code + - while/if + - ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate + - Deprecated ``cutlass.dynamic_expr``, please remove it + - Rename mbarrier functions to reduce ambiguity + - Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier` + - Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional. + +### CUTLASS C++ +* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). + - Add variable sequence length support for FMHA Backward kernel. + - Add varlen test support to Backward runner. + - Codes support empty batch sequences. +* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays. +* CuTe changes: + - Rewrite ArithTuple and ScaledBasis for robustness and clarity. + - Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX. + - Factor out `print_latex` and friends and rewrite. + - Factor out `print_svg` and friends and rewrite. +* Support Blackwell SM100 SIMT FFMA2 kernels. +* Support residual add for implicit gemm kernels. +* Various fixes for CUTLASS C++ Python interface's EVT tracer: + - Add verifier for sm90 to report the invalid input. + - When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges. + - Register operations of tanh, sigmoid, exp, gelu to the python ast frontend. + - Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback. +* Fix profiler bugs in exhaustive perf search. + - Fix incorrect cluster shape output issue when doing exhaustive search. + - Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders. + ## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03) ### CuTe DSL @@ -9,7 +49,7 @@ - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL) - [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html) - [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html) -* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass) +* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass) * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels - [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py) - [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py) diff --git a/README.md b/README.md index 3e0afb1e..709dffd7 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.0.0 +# CUTLASS 4.1.0 -_CUTLASS 4.0.0 - May 2025_ +_CUTLASS 4.1.0 - July 2025_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -43,62 +43,45 @@ To get started quickly - please refer : - [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). - [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html). -# What's New in CUTLASS 4.0 +# What's New in CUTLASS 4.1 -### CuTe DSL -* CuTe DSL, a Python DSL centered around CuTe's abstractions - - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL) - - [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html) - - [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html) -* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass) -* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels - - [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py) - - [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py) - - [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py) - - [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py) - - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py) - - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py) - - [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py) - - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py) -* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks) +## CuTe DSL +* More examples demonstrating how to use CuTe DSL to write peak-performance kernels + - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py) * API updates - - Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer`` + - for loop + - Python built-in ``range`` now always generates IR and executes at runtime + - ``cutlass.range`` is advanced ``range`` with IR level unrolling and pipelining control + - Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range`` + - **Experimental** Added ``pipelining`` control for compiler generated software pipeline code + - while/if + - ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate + - Deprecated ``cutlass.dynamic_expr``, please remove it + - Rename mbarrier functions to reduce ambiguity + - Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier` + - Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional. -### CUTLASS C++ -* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9 - - 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note 101a is supported since CUTLASS 3.9 -* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names. - - For example: - + `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma` - + `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma` - - If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, filter kernels in the CUTLASS profiler with `--kernels`), please update your uses accordingly, this is a breaking change. -* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell. - - Added non-power-of-two tile sizes. - - Improved performance for K-major scale factors. - - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions. -* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). - - Support LSE output in FMHA Forward kernel. - - Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent. - - Enhance testing of variable sequence length. - - Disable B2B mode in MLA to simplify the sample. - - Clarify that `fmha_gen` sample only supports head dim 128. - - Fixes for split-kv output in MLA. -* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support. - - Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added. - - Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle. -* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/). -* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels. -* Fix profiler issues which cause no output or not supported error for some kernels. -* Optimizations for Blackwell SM100 and SM120 block scaled kernels. -* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler. -* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). +## CUTLASS C++ +* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). + - Add variable sequence length support for FMHA Backward kernel. + - Add varlen test support to Backward runner. + - Codes support empty batch sequences. +* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays. * CuTe changes: - - Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda and introduces transform-tensors to avoid any extra register or load/store overhead in using bool-tensors. - - New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy. - - Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp). - - Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms. -* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! -* Optimal code generation with CUDA toolkit versions 12.9. + - Rewrite ArithTuple and ScaledBasis for robustness and clarity. + - Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX. + - Factor out `print_latex` and friends and rewrite. + - Factor out `print_svg` and friends and rewrite. +* Support Blackwell SM100 SIMT FFMA2 kernels. +* Support residual add for implicit gemm kernels. +* Various fixes for CUTLASS C++ Python interface's EVT tracer: + - Add verifier for sm90 to report the invalid input. + - When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges. + - Register operations of tanh, sigmoid, exp, gelu to the python ast frontend. + - Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback. +* Fix profiler bugs in exhaustive perf search. + - Fix incorrect cluster shape output issue when doing exhaustive search. + - Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders. Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. diff --git a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu index f729b43d..403472ad 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu @@ -40,10 +40,10 @@ Similar to 70_blackwell_gemm, this kernel leverages: 1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). - - 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM - which allows us to decouple the execution of MMA and epilogue into separate warps. - + + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). Usage: @@ -119,7 +119,7 @@ using MmaTileShape = Shape<_256,_256,_256>; // M using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -190,13 +190,7 @@ cutlass::HostTensor block_referen template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -329,7 +323,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -413,7 +407,7 @@ bool verify(const Options &options) { auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); - + cutlass::reference::host::GettBlockScalingEpilogueParams< ElementAccumulator, // ElementScalar ElementAccumulator, // ElementAccumulator @@ -514,9 +508,9 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 10 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; return 0; diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu index 8be4f639..6c28c552 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu @@ -39,10 +39,10 @@ 1. Blockscaled tcgen05.mma instructions. 2. Per-SM memory called Tensor Memory (TMEM) - - 3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM - which allows us to decouple the execution of MMA and epilogue into separate warps. - + + 3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + 4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). Usage: @@ -129,13 +129,13 @@ constexpr int OutputSFVectorSize = InputSFVectorSize; // With BlockScaleFactor generation. using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< OutputSFVectorSize, - ElementD, - ElementCompute, + ElementD, + ElementCompute, ElementSFD, LayoutSFDTag, ElementC>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -219,13 +219,7 @@ cutlass::HostTensor block_N template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -358,7 +352,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -456,7 +450,7 @@ bool verify(const Options &options) { decltype(tensor_B), // TensorB decltype(tensor_SFB) // TensorSfB > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; - + Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD); @@ -569,9 +563,9 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 10 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; return 0; diff --git a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu index 1d6c1f3c..aff311c1 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu @@ -41,10 +41,10 @@ 1. Blockscaled tcgen05.mma instructions. 2. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). - - 3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM - which allows us to decouple the execution of MMA and epilogue into separate warps. - + + 3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + 4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). Usage: @@ -120,7 +120,7 @@ using MmaTileShape = Shape<_256,_256,_256>; // M using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -191,13 +191,7 @@ cutlass::HostTensor block_referen template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -330,7 +324,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -414,7 +408,7 @@ bool verify(const Options &options) { auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); - + cutlass::reference::host::GettBlockScalingEpilogueParams< ElementAccumulator, // ElementScalar ElementAccumulator, // ElementAccumulator @@ -515,14 +509,14 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 10 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; return 0; } - + // // Parse options // diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu index 50d37945..0632714a 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -130,8 +130,8 @@ constexpr int OutputSFVectorSize = 16; using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor< cutlass::epilogue::thread::SiLu, OutputSFVectorSize, - ElementD, - ElementAccumulator, + ElementD, + ElementAccumulator, ElementSFD, LayoutC, ElementC>; @@ -222,7 +222,7 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutS using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< - OutputSFVectorSize, + OutputSFVectorSize, cute::is_same_v ? cute::UMMA::Major::K : cute::UMMA::Major::MN >; @@ -287,13 +287,7 @@ cutlass::DeviceAllocation norm_constant_device; template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -529,7 +523,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -785,9 +779,9 @@ bool verify(const Options &options) { decltype(tensor_SFA), decltype(tensor_B), decltype(tensor_SFB) - > + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; - + auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C); auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D); @@ -857,7 +851,7 @@ int run(Options &options, bool host_problem_shapes_available = true) } else { std::cout << " Verfication is turned off for this run." << std::endl; - } + } // Run profiling loop if (options.iterations > 0) @@ -933,7 +927,7 @@ int main(int argc, char const **args) { std::cout << "Running kernel with 1SM MMA config:" << std::endl; run(options, false /*host_problem_shapes_available*/); std::cout << "Running kernel with 2SM MMA config:" << std::endl; - run(options, false /*host_problem_shapes_available*/); + run(options, false /*host_problem_shapes_available*/); #endif return 0; diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu index 1c02a29e..67188b51 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu @@ -114,12 +114,16 @@ struct Options { int h_k = 1; int q = 1024; int k = 1024; + std::vector varlen_q; + std::vector varlen_k; int d = 128; int iterations = 3; bool verify = false; bool verbose = false; bool causal = false; + bool residual = false; + bool varlen = false; int sm_count = 0; std::string kernel_filter; @@ -177,13 +181,75 @@ struct Options { cmd.get_cmd_line_argument("h", h, -1); if (h == -1) h = 2048 / d; + varlen = cmd.check_cmd_line_flag("varlen"); + cmd.get_cmd_line_argument("q", q, -1); cmd.get_cmd_line_argument("k", k, -1); + cmd.get_cmd_line_argument("b", b, -1); + std::string varlen_q_str; + cmd.get_cmd_line_argument("varlen-q", varlen_q_str); + std::string varlen_k_str; + cmd.get_cmd_line_argument("varlen-k", varlen_k_str); + + if (varlen && ! varlen_q_str.empty()) { + varlen_q.clear(); + while (! varlen_q_str.empty()) { + size_t pos = varlen_q_str.find(':'); + varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos))); + if (pos == std::string::npos) { + break; + } + varlen_q_str = varlen_q_str.substr(pos + 1); + } + if (b == -1) { + b = static_cast(varlen_q.size()); + } + if (b != static_cast(varlen_q.size())) { + std::cout << "Error: Invalid --varlen-q length\n"; + std::exit(-1); + } + int new_q = 0; + for (auto elem : varlen_q) { + new_q += elem; + } + if (q != -1) { + std::cout << "Error: Can't provide --q and --varlen-q\n"; + std::exit(-1); + } + q = new_q; + } + + if (varlen && ! varlen_k_str.empty()) { + varlen_k.clear(); + while (! varlen_k_str.empty()) { + size_t pos = varlen_k_str.find(':'); + varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos))); + if (pos == std::string::npos) { + break; + } + varlen_k_str = varlen_k_str.substr(pos + 1); + } + if (b == -1) { + b = static_cast(varlen_k.size()); + } + if (b != static_cast(varlen_k.size())) { + std::cout << " Error: Invalid --varlen-k length\n"; + std::exit(-1); + } + int new_k = 0; + for (auto elem : varlen_k) { + new_k += elem; + } + if (k != -1) { + std::cout << "Error: Can't provide --k and --varlen-k\n"; + std::exit(-1); + } + k = new_k; + } + if (q == -1) q = k; if (k == -1) k = q; if (q == -1 && k == -1) q = k = defaults.q; - - cmd.get_cmd_line_argument("b", b, -1); if (b == -1) b = 16384 / k; if (b == 0) b = 1; @@ -195,9 +261,15 @@ struct Options { if (mask == "causal") { causal = true; } + else if (mask == "residual") { + residual = true; + } else { causal = defaults.causal; } + if (varlen) { + residual = true; + } skip_reference = cmd.check_cmd_line_flag("skip-reference"); cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); @@ -226,11 +298,18 @@ struct Options { << " --h= Sets the H extent\n" << " --q= Sets the Q extent\n" << " --k= Sets the K extent\n" - << " --d= Sets the D extentn" + << " --varlen-q=: Sets the variable Q extent per batch (colon separated)\n" + << " --varlen-k=: Sets the variable K extent per batch (colon separated)\n" + << " --d= Sets the D extent\n" << " --iterations= Benchmarking iterations\n" << " --verify Verify results\n" << " --verbose Print smem and execution time per kernel\n" - << " --mask= Enables masking\n" + << " --mask= Enables masking\n" + << " --varlen Enables variable sequence length\n" + << " B*Q and B*K become the total sequence length\n" + << " and are split B-ways, alternatingly +10% and -10%\n" + << " with the last batch sized to make it fit\n" + << " implies at least residual masking for correctness\n" << " --sm-count Sets SM count rather than querying it\n" << " --kernel-filter= Sets regexp to match kernel against\n" << "\n"; @@ -307,6 +386,7 @@ struct ExampleResult { /////////////////////////////////////////////////////////////////////////////////////////////////// template< + bool kIsVarlen, class TileShape, class DispatchPolicy, class ActiveMask, @@ -322,9 +402,11 @@ struct BwdRunner { using ElementAccumulator = float; // Q K D (H B) - using ProblemShapeType = cute::tuple>; - - using Operation = cutlass::fmha::device::Sm100FmhaBwd; + using ProblemShape = std::conditional_t< + kIsVarlen, + cute::tuple>, + cute::tuple> + >; using TensorStride = Stride>; // Seq D (H B) using StrideQ = TensorStride; @@ -363,6 +445,9 @@ struct BwdRunner { DeviceAllocation block_O; DeviceAllocation block_LSE; + DeviceAllocation block_cumulative_seqlen_q; + DeviceAllocation block_cumulative_seqlen_kv; + DeviceAllocation block_dQ; DeviceAllocation block_dK; DeviceAllocation block_dV; @@ -375,7 +460,7 @@ struct BwdRunner { // // Methods // - bool verify(const ProblemShapeType& problem_shape) { + bool verify(const ProblemShape& problem_shape) { auto [Q, K, D, HB] = problem_shape; auto [H, B] = HB; @@ -459,22 +544,89 @@ struct BwdRunner { return passed_dQ && passed_dK && passed_dV; } + auto initialize_problem_shape(Options const& options) { + if constexpr (kIsVarlen) { + int num_batches = options.b; + + // generate Q as --b times + // gaussian (--Q, --Q / 2) sampled positive + // track cumulative + std::mt19937 rng(0x202305151552ull); + std::normal_distribution dist_q(options.q, options.q / 2); + std::normal_distribution dist_kv(options.k, options.k / 2); + + auto generate_positive_int = [](auto& dist, auto& gen) { + // "0" is a valid value we test here + return std::max(0, static_cast(dist(gen))); + }; + + std::vector cumulative_seqlen_q = {0}; + std::vector cumulative_seqlen_kv = {0}; + + int total_seqlen_q = 0; + int total_seqlen_kv = 0; + int max_seqlen_q = 0; + int max_seqlen_kv = 0; + + const bool kVarlenSame = false; + for (int i = 0; i < num_batches; i++) { + int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) : + kVarlenSame ? options.q : + generate_positive_int(dist_q, rng); + int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) : + kVarlenSame ? options.k : + generate_positive_int(dist_kv, rng); + + total_seqlen_q += seqlen_q; + total_seqlen_kv += seqlen_kv; + + max_seqlen_q = std::max(max_seqlen_q, seqlen_q); + max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv); + + cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q); + cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv); + } + + block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); + block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); + block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); + block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); + + ProblemShape problem_shape{ + {max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q}, + {max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv}, + options.d, {options.h, options.b} + }; + auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1)); + + return cute::make_tuple(problem_shape, tensor_shape); + } + else { + ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}}; + return cute::make_tuple(problem_shape, problem_shape); + } + } + /// Initialize operands to be used in the GEMM and reference GEMM - void initialize(const ProblemShapeType& problem_shape, Options const& options) { - auto [Q, K, D, HB] = problem_shape; + ProblemShape initialize(Options const& options) { + auto [problem_shape, tensor_shape] = initialize_problem_shape(options); + auto [Q, K, D, HB] = tensor_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - Q = cutlass::round_up(Q, 8); // Alignment - auto shape_QO = select<0,2,3>(problem_shape); - auto shape_KV = select<1,2,3>(problem_shape); - auto shape_LSE = select<0,3>(problem_shape); + // for varlen, Q == total_Q, K == total_K, B = 1 + // but in problem_shape, they've got to be max_Q/max_K, and B = B + + auto shape_QO = make_shape(Q, D, make_shape(H, B)); + auto shape_KV = make_shape(K, D, make_shape(H, B)); + auto shape_LSE = make_shape(Q, make_shape(H, B)); + + stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H)); + stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H)); + stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H)); - stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); - stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H)); stride_V = stride_K; stride_O = stride_Q; - stride_LSE = make_stride(_1{}, make_stride(Q, Q*H)); stride_dQ = stride_Q; stride_dK = stride_K; @@ -505,6 +657,13 @@ struct BwdRunner { initialize_block(block_V, seed + 2021, options.init_style_v); initialize_block(block_dO, seed + 2020, options.init_style_do); + initialize_block(block_dQ, seed + 2030, InitStyle::kOne); + initialize_block(block_dK, seed + 2031, InitStyle::kOne); + initialize_block(block_dV, seed + 2032, InitStyle::kOne); + initialize_block(block_ref_dQ, seed + 2033); + initialize_block(block_ref_dK, seed + 2034); + initialize_block(block_ref_dV, seed + 2035); + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), select<0,2,3>(problem_shape), stride_Q); @@ -528,15 +687,19 @@ struct BwdRunner { if (! options.skip_reference) { fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); } + + return problem_shape; } ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { - auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b)); - - initialize(problem_shape, options); + auto problem_shape = initialize(options); ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d); + ExampleResult example_result; + + using Operation = cutlass::fmha::device::Sm100FmhaBwd; + typename Operation::Arguments arguments{ problem_shape, block_Q.get(), stride_Q, @@ -554,8 +717,6 @@ struct BwdRunner { Operation op; - ExampleResult example_result; - example_result.smem_size = Operation::Kernel::SharedStorageSize; size_t workspace_size = 0; @@ -650,7 +811,7 @@ struct BwdRunner { runtime_ms /= static_cast(options.iterations); - double flops = 10.0 * (std::is_same_v ? 0.5 : 1.0); + double flops = 10.0 * (std::is_same_v ? 0.5 : 1.0); flops *= static_cast(get<0>(problem_shape)); flops *= static_cast(get<1>(problem_shape)); flops *= static_cast(get<2>(problem_shape)); @@ -688,11 +849,18 @@ struct BwdRunner { /////////////////////////////////////////////////////////////////////////////////////////////////// +int main_result = 0; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + /// Helper to print a description of the example run and its result void print_result(const std::string& description, ExampleResult result, bool verbose) { std::ios fmt(nullptr); fmt.copyfmt(std::cout); std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] "); + if (! result.passed) { + main_result = -1; + } std::cout << std::setw(32) << std::left << description; std::cout.copyfmt(fmt); std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl; @@ -706,14 +874,28 @@ void print_result(const std::string& description, ExampleResult result, bool ver struct KernelCoop {}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +auto dispatch_bool(bool value, Fn fn) { + if (value) { + return fn(std::true_type{}); + } + else { + return fn(std::false_type{}); + } +} + ////////////////////////////////////////////////////////////////////////////////////////////////// template void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { - BwdRunner runner; - auto result = runner.run(options, hw_info); - print_result(name, result, options.verbose); + dispatch_bool(options.varlen, [&](auto is_varlen) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }); }; using HeadDim = _64; @@ -726,9 +908,11 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf template void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { - BwdRunner runner; - auto result = runner.run(options, hw_info); - print_result(name, result, options.verbose); + dispatch_bool(options.varlen, [&](auto is_varlen) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }); }; using HeadDim = _128; @@ -803,7 +987,10 @@ int main_single(int argc, char const **args) { auto with_causal = [&](auto fn) { if (options.causal) { - fn(CausalMask{}); + fn(CausalForBackwardMask{}); + } + else if (options.residual) { + fn(ResidualMaskForBackward{}); } else { fn(NoMask{}); @@ -823,7 +1010,7 @@ int main_single(int argc, char const **args) { }); #endif - return 0; + return main_result; } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -831,8 +1018,6 @@ int main_single(int argc, char const **args) { int main(int argc, char const **args) { std::vector full_arguments(args, args + argc); - int result = 0; - bool recursed = false; for (size_t i = 1; i < full_arguments.size(); i++) { if (full_arguments[i].find(',') != std::string::npos) { @@ -859,7 +1044,7 @@ int main(int argc, char const **args) { main_single(argc, args); } - return result; + return main_result; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 50ca376f..ae3ceb0c 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -141,25 +141,26 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_COMMAND_OPTIONS TEST_BASIC TEST_VARLEN + # NOTE: bwd doesn't support GQA yet, --h_k will just get ignored in these tests + TEST_VARLEN_00 + TEST_VARLEN_01 + TEST_VARLEN_02 + TEST_VARLEN_03 + TEST_VARLEN_04 + TEST_VARLEN_05 + TEST_VARLEN_06 + TEST_VARLEN_07 + TEST_VARLEN_08 + TEST_VARLEN_09 + TEST_VARLEN_10 + TEST_VARLEN_11 + TEST_VARLEN_12 + TEST_VARLEN_13 + TEST_VARLEN_14 ) target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v) - - cutlass_example_add_executable( - 77_blackwell_fmha_bwd_sat_${PREC} - 77_blackwell_fmha_bwd.cu - TEST_COMMAND_OPTIONS - TEST_BASIC - # TEST_GEN_VARLEN - TEST_GEN_HDIM64 - # TEST_GEN_GQA - # TEST_GEN_REMAP - # TEST_GEN_CACHEONLY) - ) - target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC) - target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v) endforeach() # Add a target that builds all examples diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md index a1536dc8..58ad99a8 100644 --- a/examples/77_blackwell_fmha/README.md +++ b/examples/77_blackwell_fmha/README.md @@ -55,6 +55,12 @@ The approach of this implementation is to reuse the selection logic of the colle The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back gemm (essentially turning the softmax into a no-op) for fp8 and fp16. For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them. +# Changes + +* 4.1.0: Enhanced testing of variable sequence length; disabled B2B mode in MLA + to simplify the sample, clarified that `fmha_gen` sample only supports head + dim 128. + # Copyright Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 6478b5d5..78147962 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -132,6 +132,58 @@ struct ResidualMask : NoMask { } }; +struct ResidualMaskForBackward : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (! elem_less(pos, select<0,1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } +}; + struct CausalMask : NoMask { using Base = NoMask; @@ -197,25 +249,57 @@ struct CausalMask : NoMask { }; +struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { + + using Base = CausalMask; + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size); + if (masked) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + struct VariableLength { int max_length; int* cumulative_length = nullptr; + int total_length = -1; CUTE_HOST_DEVICE operator int() const { return max_length; } }; -template struct is_variable_length : std::false_type {}; -template<> struct is_variable_length : std::true_type {}; -template constexpr bool is_variable_length_v = is_variable_length::value; +template struct is_variable_length_impl : std::false_type {}; +template<> struct is_variable_length_impl : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length_impl>::value; template CUTE_HOST_DEVICE constexpr auto apply_variable_length(Shape const& shape, Idx const& idx) { return transform_leaf(shape, [&](auto const& s) { - if constexpr (is_variable_length_v>) { + if constexpr (is_variable_length_v) { return s.cumulative_length[idx+1] - s.cumulative_length[idx]; } else { @@ -230,7 +314,7 @@ constexpr auto apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { auto new_shape = apply_variable_length(shape, idx); auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { - if constexpr (is_variable_length_v>) { + if constexpr (is_variable_length_v) { return cute::make_tuple(c, s.cumulative_length[idx]); } else { @@ -240,6 +324,30 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { return cute::make_tuple(new_shape, new_coord); } +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length_offset(Shape const& shape, Coord const& coord) { + auto idx = back(back(coord)); + auto result_shape = transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); + auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx]; + } + else { + return _0{}; + } + }); + return cute::make_tuple(result_shape, result_offset); +} + } // namespace cutlass::fmha::collective namespace cute { diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp index 80fcdf9f..3c8f7195 100644 --- a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -50,13 +50,19 @@ namespace cutlass::fmha::device { ////////////////////////////// CUTLASS 3.x API ///////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -template +template< + class ProblemShape, + class Element, + class ElementAccumulator, + class TileShape, + class Mask +> class Sm100FmhaBwd { public: /// Argument structure: User API struct Arguments { // Q K D HB - cute::tuple> problem_size; + ProblemShape problem_shape; const Element* ptr_Q; cute::tuple> stride_Q; @@ -86,14 +92,16 @@ public: }; using OperationSumOdO = cutlass::fmha::device::FMHA< - cutlass::fmha::kernel::FmhaKernelBwdSumOdO + cutlass::fmha::kernel::FmhaKernelBwdSumOdO >; using OperationConvert = cutlass::fmha::device::FMHA< - cutlass::fmha::kernel::FmhaKernelBwdConvert + cutlass::fmha::kernel::FmhaKernelBwdConvert >; using Operation = cutlass::fmha::device::FMHA< - cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > >; using Kernel = typename Operation::Kernel; @@ -113,15 +121,15 @@ private: ElementAccumulator* sum_odo = nullptr, ElementAccumulator* scaled_lse = nullptr) { using namespace cute; - auto [Q, K, D, HB] = args.problem_size; + auto [Q_, K, D, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - Q = cutlass::round_up(Q, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); auto log2_e = log2f(expf(1.0f)); return typename OperationSumOdO::Arguments { - args.problem_size, + args.problem_shape, args.ptr_O, args.stride_O, args.ptr_dO, args.stride_dO, sum_odo, stride_sum_OdO, @@ -133,13 +141,13 @@ private: static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { using namespace cute; - auto [Q, K, D, HB] = args.problem_size; + auto [Q_, K, D, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - Q = cutlass::round_up(Q, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); return typename OperationConvert::Arguments { - args.problem_size, + args.problem_shape, src, stride_src_dQ, nullptr, stride_src_dQ, nullptr, stride_src_dQ, @@ -156,7 +164,7 @@ private: ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { return typename Operation::Arguments{ - args.problem_size, + args.problem_shape, { args.ptr_Q, args.stride_Q, args.ptr_K, args.stride_K, args.ptr_V, args.stride_V, @@ -199,10 +207,10 @@ public: /// Gets the workspace size static size_t get_workspace_size(Arguments const& args) { - auto [Q, K, D, HB] = args.problem_size; + auto [Q_, K, D, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - Q = cutlass::round_up(Q, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment size_t workspace_bytes = 0; // OdO vector workspace_bytes += B*H*Q * sizeof(ElementAccumulator); @@ -219,10 +227,10 @@ public: CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); - auto [Q, K, D, HB] = args.problem_size; + auto [Q_, K, D, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - Q = cutlass::round_up(Q, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); @@ -248,10 +256,10 @@ public: CUTLASS_TRACE_HOST("Universal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); - auto [Q, K, D, HB] = args.problem_size; + auto [Q_, K, D, HB] = args.problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - Q = cutlass::round_up(Q, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); workspace_chr += B*H*Q * sizeof(ElementAccumulator); diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp index c2618bcb..c7f869f9 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel { using namespace cute; -template +template struct FmhaKernelBwdConvert { struct Arguments { - tuple> problem_size; + ProblemShape problem_shape; const ElementAcc* ptr_src_dQ; tuple> stride_src_dQ; @@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert { static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; static bool can_implement(Arguments const& args) { - return get<2>(args.problem_size) % kElementsPerLoad == 0; + return get<2>(args.problem_shape) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { - dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq)); + dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); return grid; } @@ -102,18 +102,25 @@ struct FmhaKernelBwdConvert { return args; } - template - CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) { + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) { auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + int seqlen = count; + if constexpr (is_variable_length_v) { + int offset = count.cumulative_length[blockIdx.y]; + ptr_dest_bh += offset * get<0>(stride_dest); + seqlen = count.cumulative_length[blockIdx.y + 1] - offset; + } + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { int idx_s = idx_s_t + kBlockSeq * blockIdx.z; - if (idx_s >= count) continue; + if (idx_s >= seqlen) continue; auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); - for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { ElementAcc value_src[kElementsPerLoad]; Element value_dest[kElementsPerLoad]; @@ -132,13 +139,13 @@ struct FmhaKernelBwdConvert { CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { if (params.ptr_src_dQ != nullptr) { - copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size)); + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape)); } if (params.ptr_src_dK != nullptr) { - copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size)); + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape)); } if (params.ptr_src_dV != nullptr) { - copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size)); + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape)); } } }; diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp index 44080e2d..98c127da 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel { using namespace cute; -template +template struct FmhaKernelBwdSumOdO { struct Arguments { - cute::tuple> problem_size; + ProblemShape problem_shape; const Element* ptr_O; cute::tuple> stride_O; @@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO { static const int kIterationsQ = kBlockQ / kNumThreadsQ; static bool can_implement(Arguments const& args) { - return get<2>(args.problem_size) % kElementsPerLoad == 0; + return get<2>(args.problem_shape) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { - dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size)); + dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape)); return grid; } @@ -110,10 +110,20 @@ struct FmhaKernelBwdSumOdO { auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + auto problem_q = get<0>(params.problem_shape); + int seqlen_q = problem_q; + if constexpr (is_variable_length_v) { + int offset = problem_q.cumulative_length[blockIdx.z]; + ptr_O_bh += offset * get<0>(params.stride_O); + ptr_dO_bh += offset * get<0>(params.stride_dO); + ptr_lse_bh += offset * get<0>(params.stride_lse); + seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; + } + CUTLASS_PRAGMA_UNROLL for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { int idx_q = idx_q_t + kBlockQ * blockIdx.x; - if (idx_q >= get<0>(params.problem_size)) continue; + if (idx_q >= seqlen_q) continue; ElementAcc acc = 0; auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); @@ -121,7 +131,7 @@ struct FmhaKernelBwdSumOdO { auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); - for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { Element value_O[kElementsPerLoad]; Element value_dO[kElementsPerLoad]; diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 8c5401a9..c4e3f9d5 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -43,6 +43,8 @@ #include "collective/fmha_common.hpp" +#include + namespace cutlass::fmha::kernel { using namespace cutlass::fmha::collective; @@ -50,6 +52,7 @@ using namespace cutlass::fmha::collective; using namespace cute; template< + class ProblemShape, class Element, class ElementAcc, class TileShape, @@ -274,7 +277,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); - using ProblemShape = Shape>; // Q K D (H B), eventuall D = (D_QK, D_VO) using TensorStride = TensorStrideContiguousK; // S D (H B) using RowTensorStride = Stride<_1, Stride>; // S (H B) @@ -360,7 +362,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static Params to_underlying_arguments(Arguments const& args, void*) { - auto [Q, K, D, HB] = args.problem_shape; + auto [Q_, K_, D, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } auto params_kq = CollectiveMmaKQ::to_underlying_arguments( make_shape(K, Q, D, HB), @@ -378,7 +389,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TMA_DQ tma_red_dq = make_tma_copy( SM90_TMA_REDUCE_ADD{}, - make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc), + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), SmemLayoutDQ{}(_, _, _0{}) ); @@ -416,10 +427,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } - template + template CUTLASS_DEVICE void load( BlkCoord const& blk_coord, - ProblemShape const& problem_shape, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, @@ -440,10 +452,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { uint16_t mcast_mask = 0; - auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); - auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); - auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB)); - auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB)); + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB)); + + auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in); + auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in); auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); @@ -478,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { // set up lse and sum_odo - auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); @@ -515,11 +532,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); - cutlass::arch::cp_async_zfill<16>( - shared_tensors.smem_lse.begin() + smem_idx, - &mLSE(gmem_idx, blk_coord_batch), - gmem_idx < Q - ); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; @@ -556,11 +575,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); - cutlass::arch::cp_async<16>( - shared_tensors.smem_sum_odo.begin() + smem_idx, - &mSumOdO(gmem_idx, blk_coord_batch), - gmem_idx < Q - ); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; @@ -588,11 +609,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { // load LSE smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; - cutlass::arch::cp_async<16>( - shared_tensors.smem_lse.begin() + smem_idx, - &mLSE(gmem_idx, blk_coord_batch), - gmem_idx < Q - ); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; @@ -616,11 +639,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { // load sum_OdO smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; - cutlass::arch::cp_async_zfill<16>( - shared_tensors.smem_sum_odo.begin() + smem_idx, - &mSumOdO(gmem_idx, blk_coord_batch), - gmem_idx < Q - ); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; @@ -631,10 +656,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } - template + template CUTLASS_DEVICE void mma( BlkCoord const& blk_coord, - ProblemShape const& problem_shape, + ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, @@ -932,32 +957,79 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ); auto thr_copy = copy_op.get_slice(_0{}); + Tensor quantized_regs = quantize(regs); + Tensor tCr = thr_copy.partition_S(quantized_regs); Tensor tCg = thr_copy.partition_D(gmem); - Tensor tCr = thr_copy.partition_S(quantize(regs)); Tensor tPc = thr_copy.partition_D(preds); copy_if(copy_op, tPc, tCr, tCg); } - template + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,2>(problem_shape))) { + gDV(i) = Element(0); + } + } + } + + + template CUTLASS_DEVICE void epilogue( BlkCoord const& blk_coord, - ProblemShape const& problem_shape, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { auto [Q, K, D, HB] = problem_shape; - auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); tDKtDK.data() = TmemAllocation::kDK; - auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); @@ -992,12 +1064,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); tDVtDV.data() = TmemAllocation::kDV; - auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDV = domain_offset( - make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_coord(blk_coord_k * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); @@ -1041,10 +1114,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } - template + template CUTLASS_DEVICE void compute( BlkCoord const& blk_coord, - ProblemShape const& problem_shape, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, @@ -1075,7 +1149,14 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; - auto store_op = SM100_TMEM_STORE_32dp32b8x{}; + auto store_op = []() { + if constexpr (sizeof(Element) == 1) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else { + return SM100_TMEM_STORE_32dp32b8x{}; + } + }(); Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); tSTtST.data() = TmemAllocation::kS; @@ -1093,17 +1174,32 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto thread_t2r = tiled_t2r.get_slice(dp_idx); auto split_wg = [&](auto const& t) { - if constexpr (decltype(rank(t))::value == 3) { - auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); - return p(_, _, make_coord(wg_idx, _)); + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } } else { - auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); - return p(_, _, _, make_coord(wg_idx, _)); + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + } }; - Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST)); + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); Tensor tTR_rST = make_tensor(shape(tTR_cST)); Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); @@ -1117,7 +1213,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); - auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); tDVrP.data() = TmemAllocation::kP; @@ -1126,7 +1221,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto thread_r2t = tiled_r2t.get_slice(dp_idx); auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); - auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST)); + auto tRT_cST_p = thread_r2t.partition_S(tDVcST); + auto tRT_cST = split_wg(tRT_cST_p); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; CUTLASS_PRAGMA_NO_UNROLL while (iter_count > 0) { @@ -1145,13 +1244,21 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } }; - dispatch_bool(std::is_base_of_v && - warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) { + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { // compute P = softmax(S, LSE) cute::copy(tiled_t2r, tTR_tST, tTR_rST); - if constexpr (std::is_base_of_v && decltype(is_causal_masked_tile)::value) { + if constexpr (decltype(is_masked_tile)::value) { Mask{}.apply_mask(tTR_rST, [&](int i) { auto c_transpose = tTR_cST(i); return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); @@ -1267,15 +1374,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } epilogue( - blk_coord, problem_shape, mainloop_args, epilogue_args, + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state ); } - template + template CUTLASS_DEVICE void reduce( BlkCoord const& blk_coord, - ProblemShape const& problem_shape, + ProblemShape_ const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, @@ -1290,7 +1397,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto [Q, K, D, HB] = problem_shape; - auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; // must match TileShapeDQ auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; @@ -1568,20 +1675,38 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { pipeline_init_wait(size(ClusterShape{})); - auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z)); - auto problem_shape = params.problem_shape; + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); int iter_start = 0; if constexpr (std::is_base_of_v) { iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } iter_count -= iter_start; + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + if (role == WarpRole::Load) { warpgroup_reg_set(); load( blk_coord, + blk_offset, problem_shape, iter_start, iter_count, @@ -1624,6 +1749,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { compute( blk_coord, + blk_offset, problem_shape, iter_start, iter_count, diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index 0649b087..489600f7 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -422,24 +422,30 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { static bool can_implement(Arguments const& args) { if (kIsCpAsync) { if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size pow2\n"; return false; } if (args.mainloop.page_size > TileShapeS{}) { + std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size too big\n"; return false; } } else { if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + std::cerr << __FILE__ << "(" << __LINE__ << "): tma page size off\n"; return false; } } if (get<0>(args.problem_shape) != 128) { + std::cerr << __FILE__ << "(" << __LINE__ << "): heads off\n"; return false; } if (get<1>(args.problem_shape) <= 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): heads off\n"; return false; } if (args.split_kv <= 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): split-k off\n"; return false; } return true; diff --git a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp index bb8cfb34..66883af4 100644 --- a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -44,29 +44,43 @@ template< class Fusion > void __global__ fmha_bwd_reference_dQ_kernel( - ProblemShape problem_shape, - TensorQ mQ, TensorK mK, TensorV mV, - TensorO mO, TensorLSE mLSE, TensorDO mDO, - TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */ + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */ Fusion fusion) { using namespace cute; + using namespace cutlass::fmha::collective; using Element = typename TensorO::value_type; using ElementAccumulator = typename TensorLSE::value_type; extern __shared__ char mS_mem[]; - Element* mS = reinterpret_cast(mS_mem); + ElementAccumulator* mS = reinterpret_cast(mS_mem); - Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in))); - for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) { - for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) { - for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) + ); + // problem_shape = problem_shape_in; + // offset = repeat_like(problem_shape_in, _0{}); + auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); + auto mK = domain_offset(select<1,2,3>(offset), mK_in); + auto mV = domain_offset(select<1,2,3>(offset), mV_in); + auto mO = domain_offset(select<0,2,3>(offset), mO_in); + auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); + auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); + auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in); + for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { ElementAccumulator acc_qk = 0; ElementAccumulator acc_dov = 0; ElementAccumulator acc_doo = 0; - for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); @@ -78,15 +92,15 @@ void __global__ fmha_bwd_reference_dQ_kernel( fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); acc_qk = frag(0); - mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); } // for idx_K __syncthreads(); - for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) { + for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { ElementAccumulator acc = 0; - for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { - acc += mS[idx_K] * mK(idx_K, idx_D, idx_L); + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + acc += mS[idx_K] * ElementAccumulator(mK(idx_K, idx_D, idx_L)); } mDQ(idx_Q, idx_D, idx_L) = static_cast(acc); } // for idx_D @@ -104,29 +118,43 @@ template< class Fusion > void __global__ fmha_bwd_reference_dK_kernel( - ProblemShape problem_shape, - TensorQ mQ, TensorK mK, TensorV mV, - TensorO mO, TensorLSE mLSE, TensorDO mDO, - /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */ + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + /* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */ Fusion fusion) { using namespace cute; + using namespace cutlass::fmha::collective; using Element = typename TensorO::value_type; using ElementAccumulator = typename TensorLSE::value_type; extern __shared__ char mS_mem[]; - Element* mS = reinterpret_cast(mS_mem); + ElementAccumulator* mS = reinterpret_cast(mS_mem); - Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in))); - for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) { - for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) { - for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) + ); + // problem_shape = problem_shape_in; + // offset = repeat_like(problem_shape_in, _0{}); + auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); + auto mK = domain_offset(select<1,2,3>(offset), mK_in); + auto mV = domain_offset(select<1,2,3>(offset), mV_in); + auto mO = domain_offset(select<0,2,3>(offset), mO_in); + auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); + auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); + auto mDK = domain_offset(select<1,2,3>(offset), mDK_in); + for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { ElementAccumulator acc_qk = 0; ElementAccumulator acc_dov = 0; ElementAccumulator acc_doo = 0; - for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); @@ -138,15 +166,15 @@ void __global__ fmha_bwd_reference_dK_kernel( fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); acc_qk = frag(0); - mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); } // for idx_Q __syncthreads(); - for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) { + for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { ElementAccumulator acc = 0; - for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { - acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L); + for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) { + acc += mS[idx_Q] * ElementAccumulator(mQ(idx_Q, idx_D, idx_L)); } mDK(idx_K, idx_D, idx_L) = static_cast(acc); } // for idx_D @@ -164,28 +192,42 @@ template< class Fusion > void __global__ fmha_bwd_reference_dV_kernel( - ProblemShape problem_shape, - TensorQ mQ, TensorK mK, TensorV mV, - TensorO mO, TensorLSE mLSE, TensorDO mDO, - /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV, + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + /* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in, Fusion fusion) { using namespace cute; + using namespace cutlass::fmha::collective; using Element = typename TensorO::value_type; using ElementAcc = typename TensorLSE::value_type; extern __shared__ char mS_mem[]; - Element* mS = reinterpret_cast(mS_mem); + ElementAcc* mS = reinterpret_cast(mS_mem); - ElementAcc softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in))); - for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) { - for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) { - for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) + ); + // problem_shape = problem_shape_in; + // offset = repeat_like(problem_shape_in, _0{}); + auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); + auto mK = domain_offset(select<1,2,3>(offset), mK_in); + auto mV = domain_offset(select<1,2,3>(offset), mV_in); + auto mO = domain_offset(select<0,2,3>(offset), mO_in); + auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); + auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); + auto mDV = domain_offset(select<1,2,3>(offset), mDV_in); + for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { ElementAcc acc_qk = 0; - for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L); ElementAcc rK = mK(idx_K, idx_D0, idx_L); acc_qk += rQ * rK; @@ -197,15 +239,15 @@ void __global__ fmha_bwd_reference_dV_kernel( fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); acc_qk = frag(0); - mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L))); + mS[idx_Q] = expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)); } // for idx_Q __syncthreads(); - for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) { + for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { ElementAcc acc = 0; - for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { - ElementAcc rS = mS[idx_Q]; + for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) { + ElementAcc rS = static_cast(mS[idx_Q]); ElementAcc rDO = mDO(idx_Q, idx_D, idx_L); acc += rS * rDO; } @@ -235,7 +277,7 @@ void fmha_bwd_reference_dQ( dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); dim3 block(256); - int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + int shared_mem = size<0>(mK) * sizeof(typename TensorLSE::value_type); fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); } @@ -259,7 +301,7 @@ void fmha_bwd_reference_dK( dim3 grid(size<0>(mDK), size<2>(mDK), 1); dim3 block(256); - int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type); fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); } @@ -283,7 +325,7 @@ void fmha_bwd_reference_dV( dim3 grid(size<0>(mDV), size<2>(mDV), 1); dim3 block(256); - int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type); fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); } diff --git a/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu b/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu index 058c4b2b..7e186c72 100644 --- a/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu @@ -36,7 +36,7 @@ This kernel is optimized for the GeForce RTX 50 series GPUs. The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale). - NVFP4 MMA has 2x throughput compared to MXFP8 MMA and 4x throughput compared to Ada Tensor Core FP8 MMA. + NVFP4 MMA has 2x throughput compared to MXFP8 MMA and 4x throughput compared to Ada Tensor Core FP8 MMA. (See https://docs.nvidia.com/cuda/parallel-thread-execution). This kernel leverages: @@ -44,11 +44,11 @@ 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). 3. Block Scaled Tensor Core MMA Instructions 4. Epilogue Optimization - + Note that GeForce RTX 50 series GPUs do not support: 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. 2. Dynamic datatypes. - + Usage: $ ./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048 @@ -122,7 +122,7 @@ using ThreadBlockShape = Shape<_128,_128,_128>; // T using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, ThreadBlockShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -193,13 +193,7 @@ cutlass::HostTensor block_referen template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -328,7 +322,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -411,7 +405,7 @@ bool verify(const Options &options) { auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); - + cutlass::reference::host::GettBlockScalingEpilogueParams< ElementAccumulator, // ElementScalar ElementAccumulator, // ElementAccumulator @@ -512,9 +506,9 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 12 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; return 0; diff --git a/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu b/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu index e3ebba4a..a0aa01c6 100644 --- a/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu @@ -37,7 +37,7 @@ This kernel is optimized for the GeForce RTX 50 series GPUs. Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages: - + 1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper. 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). 3. Block Scaled Tensor Core MMA Instructions @@ -46,7 +46,7 @@ Note that GeForce RTX 50 series GPUs do not support: 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. 2. Dynamic datatypes. - + Usage: $ ./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048 @@ -130,13 +130,13 @@ constexpr int OutputSFVectorSize = InputSFVectorSize; // With BlockScaleFactor generation. using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< OutputSFVectorSize, - ElementD, - ElementCompute, + ElementD, + ElementCompute, ElementSFD, LayoutSFDTag, ElementC>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, ThreadBlockShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -221,13 +221,7 @@ cutlass::HostTensor block_N template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -356,7 +350,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -455,7 +449,7 @@ bool verify(const Options &options) { auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); auto tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD); - + cutlass::reference::host::GettBlockScalingEpilogueParams< ElementAccumulator, // ElementScalar ElementAccumulator, // ElementAccumulator @@ -559,9 +553,9 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 12 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; return 0; diff --git a/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu b/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu index ac2f39c9..655719d9 100644 --- a/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu @@ -36,7 +36,7 @@ This kernel is optimized for the GeForce RTX 50 series GPUs. The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale). - MXFP8 MMA has 2x throughput compared to Ada Tensor Core FP8 MMA. + MXFP8 MMA has 2x throughput compared to Ada Tensor Core FP8 MMA. (See https://docs.nvidia.com/cuda/parallel-thread-execution). Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages: @@ -48,7 +48,7 @@ Note that GeForce RTX 50 series GPUs do not support: 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. 2. Dynamic datatypes. - + Usage: $ ./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048 @@ -122,7 +122,7 @@ using ThreadBlockShape = Shape<_128,_128,_128>; // T using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, ThreadBlockShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -193,13 +193,7 @@ cutlass::HostTensor block_referen template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -328,7 +322,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -411,7 +405,7 @@ bool verify(const Options &options) { auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); - + cutlass::reference::host::GettBlockScalingEpilogueParams< ElementAccumulator, // ElementScalar ElementAccumulator, // ElementAccumulator @@ -512,9 +506,9 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 12 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; return 0; diff --git a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu index 2b55c465..48df9108 100644 --- a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu @@ -137,8 +137,8 @@ constexpr int OutputSFVectorSize = 16; // With BlockScaleFactor generation. using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< OutputSFVectorSize, - ElementD, - ElementCompute, + ElementD, + ElementCompute, ElementSFD, LayoutCTag, ElementC>; @@ -201,7 +201,7 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutS using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< - OutputSFVectorSize, + OutputSFVectorSize, cute::is_same_v ? cute::UMMA::Major::K : cute::UMMA::Major::MN >; @@ -267,13 +267,7 @@ cutlass::DeviceAllocation norm_constant_device; template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -511,7 +505,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } @@ -760,9 +754,9 @@ bool verify(const Options &options) { decltype(tensor_SFA), decltype(tensor_B), decltype(tensor_SFB) - > + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; - + auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C); auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D); auto tensor_ref_SFD = cute::make_tensor(make_iterator(block_ref_SFD.at(i).host_data()), layout_SFD); @@ -777,7 +771,7 @@ bool verify(const Options &options) { cute::Int, cutlass::reference::host::SfStrategy::SfDGen > epilogue_params {alpha_host.at(i), beta_host.at(i), tensor_C, tensor_ref_D, tensor_ref_SFD, options.norm_constant}; - + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); // Comparison @@ -842,7 +836,7 @@ int run(Options &options, bool host_problem_shapes_available = true) } else { std::cout << " Verfication is turned off for this run." << std::endl; - } + } // Run profiling loop if (options.iterations > 0) @@ -918,7 +912,7 @@ int main(int argc, char const **args) { std::cout << "Running kernel with Cooperative kernel schedule:" << std::endl; run(options, false /*host_problem_shapes_available*/); std::cout << "Running kernel with Pingpong kernel schedule:" << std::endl; - run(options, false /*host_problem_shapes_available*/); + run(options, false /*host_problem_shapes_available*/); #endif return 0; diff --git a/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu index 32df1146..e7253c24 100644 --- a/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu +++ b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu @@ -108,7 +108,7 @@ using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, ThreadBlockShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -175,13 +175,7 @@ cutlass::HostTensor block_referen #endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Testbed utility types @@ -289,7 +283,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } /// Initialize blocks that released to sparse Matrix A and its metadata E @@ -465,7 +459,7 @@ template int run(Options &options) { // Initialization - if(!initialize(options)) + if(!initialize(options)) { std::cerr << " Initialization failed! " << std::endl; exit(-1); @@ -527,9 +521,9 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 12 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; return 0; diff --git a/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu index f3441b56..b5ba430e 100644 --- a/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu +++ b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu @@ -111,7 +111,7 @@ using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm1 using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, + ArchTag, OperatorClass, ThreadBlockShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, @@ -186,13 +186,7 @@ cutlass::HostTensor block_N #endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Testbed utility types @@ -300,7 +294,7 @@ bool initialize_block( } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); - + return true; } /// Initialize blocks that released to sparse Matrix A and its metadata E @@ -489,7 +483,7 @@ template int run(Options &options) { // Initialization - if(!initialize(options)) + if(!initialize(options)) { std::cerr << " Initialization failed! " << std::endl; exit(-1); @@ -551,9 +545,9 @@ int main(int argc, char const **args) { cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); - + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - + if (!(props.major == 12 && props.minor == 0)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; return 0; diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu b/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu index d2d87c46..4f1b4f49 100644 --- a/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu @@ -41,8 +41,8 @@ Similar to 83_blackwell_sparse_gemm, this kernel leverages: 1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). - 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM - which allows us to decouple the execution of MMA and epilogue into separate warps. + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). @@ -123,8 +123,8 @@ using ArchTag = cutlass::arch::Sm100; // using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag // MMA and Cluster Tile Shapes -// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 -using MmaTileShape = Shape<_256,_128,_256>; +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape = Shape<_256,_128,_256>; // Shape of the threadblocks in a cluster using ClusterShape = Shape<_2,_1,_1>; @@ -157,7 +157,7 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< ProblemShape, CollectiveMainloop, CollectiveEpilogue, - void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -244,13 +244,7 @@ cutlass::HostTensor reference_D; template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -536,18 +530,18 @@ bool verify(const Options &options) { // Create the arguments for host reference implementation auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A); auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA); - auto B = make_tensor(make_iterator(tensor_B.host_data()), + auto B = make_tensor(make_iterator(tensor_B.host_data()), make_layout(make_shape(options.n, options.k, options.l), stride_B)); auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB); cutlass::reference::host::GettMainloopParams< - ElementAccumulator, - decltype(A), - decltype(B), - decltype(SFA), + ElementAccumulator, + decltype(A), + decltype(B), + decltype(SFA), decltype(SFB)> mainloop_params{A, SFA, B, SFB}; - auto C = make_tensor(make_iterator(tensor_C.host_data()), + auto C = make_tensor(make_iterator(tensor_C.host_data()), make_layout(make_shape(options.m, options.n, options.l), stride_C)); auto D = make_tensor(make_iterator(reference_D.host_data()), make_layout(make_shape(options.m, options.n, options.l), stride_D)); @@ -563,7 +557,7 @@ bool verify(const Options &options) { options.beta, C, D}; - + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); // Comparison diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu b/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu index a23af158..ae922472 100644 --- a/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu @@ -41,8 +41,8 @@ Similar to 83_blackwell_sparse_gemm, this kernel leverages: 1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). - 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM - which allows us to decouple the execution of MMA and epilogue into separate warps. + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). @@ -123,8 +123,8 @@ using ArchTag = cutlass::arch::Sm100; // using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag // MMA and Cluster Tile Shapes -// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 -using MmaTileShape_MNK = Shape<_256,_128,_256>; +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_256>; // Shape of the threadblocks in a cluster using ClusterShape_MNK = Shape<_2,_1,_1>; @@ -157,7 +157,7 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< ProblemShape, CollectiveMainloop, CollectiveEpilogue, - void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -244,13 +244,7 @@ cutlass::HostTensor reference_D; template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -536,18 +530,18 @@ bool verify(const Options &options) { // Create the arguments for host reference implementation auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A); auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA); - auto B = make_tensor(make_iterator(tensor_B.host_data()), + auto B = make_tensor(make_iterator(tensor_B.host_data()), make_layout(make_shape(options.n, options.k, options.l), stride_B)); auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB); cutlass::reference::host::GettMainloopParams< - ElementAccumulator, - decltype(A), - decltype(B), - decltype(SFA), + ElementAccumulator, + decltype(A), + decltype(B), + decltype(SFA), decltype(SFB)> mainloop_params{A, SFA, B, SFB}; - auto C = make_tensor(make_iterator(tensor_C.host_data()), + auto C = make_tensor(make_iterator(tensor_C.host_data()), make_layout(make_shape(options.m, options.n, options.l), stride_C)); auto D = make_tensor(make_iterator(reference_D.host_data()), make_layout(make_shape(options.m, options.n, options.l), stride_D)); @@ -560,7 +554,7 @@ bool verify(const Options &options) { decltype(C), // TensorC decltype(D) // TensorD > epilogue_params{}; - + epilogue_params.C = C; epilogue_params.D = D; epilogue_params.alpha = options.alpha; diff --git a/examples/python/CuTeDSL/ampere/call_from_jit.py b/examples/python/CuTeDSL/ampere/call_from_jit.py new file mode 100644 index 00000000..ffe2eb70 --- /dev/null +++ b/examples/python/CuTeDSL/ampere/call_from_jit.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Demonstrating JIT GEMM Implementation with Static Shape Wrapper + +This example illustrates how to invoke a JIT-compiled GEMM implementation through a wrapper function +with static shapes. It showcases the integration between PyTorch and CuTe tensors in a JIT context. + +Key features demonstrated: +1. Seamless conversion between PyTorch and CuTe tensors using the JitArgument protocol +2. Integration of static shape GEMM operations within a JIT-compiled wrapper function + +Core components: +- BufferWithLayout: Handles memory buffer management with configurable stride ordering +- tensor_op_gemm_wrapper: JIT-compiled entry point that orchestrates the GEMM operation + +Usage: + +.. code-block:: bash + + python examples/ampere/call_from_jit.py + +Default configuration: +- Batch dimension (L): 16 +- Matrix dimensions: M=512, N=256, K=128 +- Precision: Float16 inputs with Float32 accumulation + +Requirements: +- CUDA-capable GPU +- PyTorch with CUDA support +""" + +import os +import sys +from typing import Type, Tuple + +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.torch import dtype as torch_dtype +from cutlass.cute.runtime import make_ptr + + +# Add the current directory to sys.path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from tensorop_gemm import TensorOpGemm + + +class BufferWithLayout: + def __init__(self, ptr: cute.Pointer, stride_order: tuple[int, int, int]): + self.ptr = ptr + + # static properties + self.stride_order = stride_order + + def to_tensor( + self, shape: tuple[int, int, int], *, loc=None, ip=None + ) -> cute.Tensor: + assert len(shape) == len(self.stride_order), ( + f"Shape {shape} and stride_order {self.stride_order} must have the " + "same rank." + ) + layout = cute.make_ordered_layout(shape, self.stride_order) + # permute (l, mn, k) -> (mn, k, l) + res = cute.make_tensor(self.ptr, cute.select(layout, mode=[1, 2, 0])) + return res + + # Implement JitArgument Protocol and DynamicExpression Protocol + + def __c_pointers__(self): + """Get the C pointers for the underlying pointer. + + This method is part of the JitArgument Protocol and returns the C pointers + from the underlying pointer object. + + This is required for user to define a custom data type which can pass to JIT function. + When JIT compiled function is called, JIT executor will call this method to get raw pointers + to underlying data object. + + Following condition must be satisfied: + + len(__c_pointers__()) == len(__get_mlir_types__()) == len(__extract_mlir_values__()) + + :return: The C pointers from the underlying pointer object + :rtype: Any + """ + return self.ptr.__c_pointers__() + + def __get_mlir_types__(self): + """Get the MLIR types for the underlying pointer. + + This method is part of the JitArgument Protocol and returns the MLIR types + used for compiler to generate code. It must match the type of the underlying pointers + returned by __c_pointers__(). + + :return: The MLIR types from the underlying pointer object + :rtype: Any + """ + return self.ptr.__get_mlir_types__() + + def __extract_mlir_values__(self): + """Extract MLIR values from the underlying pointer. + + This method is part of the DynamicExpression Protocol and extracts MLIR values + from the underlying pointer object. + + It is used by compiler to generate function call in MLIR to another JIT function. + It must match the types returned by __get_mlir_types__(). + + :return: The MLIR values extracted from the underlying pointer object + :rtype: Any + """ + return self.ptr.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new BufferWithLayout instance from MLIR values. + + This method is part of the JitArgument & DynamicExpression Protocol and creates a new + BufferWithLayout instance with pointer initialized from the given MLIR values. + + It is used by compiler to generate function body in MLIR called by JIT function. + It must match the types returned by __c_pointers__() and __get_mlir_types__(). + code generator takes function arguments and reconstructs python object which is legal + inside function body. + + :param values: MLIR values to initialize the underlying pointer + :type values: Any + :return: A new BufferWithLayout instance with pointer initialized from values + :rtype: BufferWithLayout + """ + return BufferWithLayout( + self.ptr.__new_from_mlir_values__(values), self.stride_order + ) + + +@cute.jit +def tensor_op_gemm_wrapper( + buffer_a: BufferWithLayout, + buffer_b: BufferWithLayout, + buffer_c: BufferWithLayout, + mnkl: cutlass.Constexpr[tuple[int, int, int, int]], + acc_dtype: Type[cutlass.Numeric], + atom_layout_mnk: cutlass.Constexpr[tuple[int, int, int]], +): + print(f"\n[DSL INFO] Input Parameters:") + print(f"[DSL INFO] mnkl: {mnkl}") + print(f"[DSL INFO] buffer_a: {buffer_a}") + print(f"[DSL INFO] buffer_b: {buffer_b}") + print(f"[DSL INFO] buffer_c: {buffer_c}") + print(f"[DSL INFO] acc_dtype: {acc_dtype}") + print(f"[DSL INFO] atom_layout_mnk: {atom_layout_mnk}") + + mA = buffer_a.to_tensor(cute.select(mnkl, mode=[3, 0, 2])) + mB = buffer_b.to_tensor(cute.select(mnkl, mode=[3, 1, 2])) + mC = buffer_c.to_tensor(cute.select(mnkl, mode=[3, 0, 1])) + + print(f"\n[DSL INFO] Created Tensors:") + print(f"[DSL INFO] mA = {mA}") + print(f"[DSL INFO] mB = {mB}") + print(f"[DSL INFO] mC = {mC}") + + tensor_op_gemm = TensorOpGemm( + buffer_a.ptr.value_type, + buffer_c.ptr.value_type, + acc_dtype, + atom_layout_mnk, + ) + print(f"\n[DSL INFO] Created TensorOpGemm instance") + print(f"[DSL INFO] Input dtype: {buffer_a.ptr.value_type}") + print(f"[DSL INFO] Output dtype: {buffer_c.ptr.value_type}") + print(f"[DSL INFO] Accumulation dtype: {acc_dtype}") + print(f"[DSL INFO] Atom layout: {atom_layout_mnk}") + + # No need to compile inside jit function + tensor_op_gemm(mA, mB, mC) + print(f"\n[DSL INFO] Executed TensorOpGemm") + + +def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): + print(f"\nRunning TensorOpGemm test with:") + print(f"Tensor dimensions: {mnkl}") + + ab_dtype = cutlass.Float16 + c_dtype = cutlass.Float16 + + a = torch.randn( + mnkl[3], mnkl[0], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda" + ) + b = torch.randn( + mnkl[3], mnkl[1], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda" + ) + c = torch.randn( + mnkl[3], mnkl[0], mnkl[1], dtype=torch_dtype(c_dtype), device="cuda" + ) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + buffer_a = BufferWithLayout( + make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem), + (2, 1, 0), + ) + buffer_b = BufferWithLayout( + make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem), + (2, 1, 0), + ) + buffer_c = BufferWithLayout( + make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem), + (2, 1, 0), + ) + + tensor_op_gemm_wrapper( + buffer_a, + buffer_b, + buffer_c, + mnkl, # pass shape as static value + # no stride passing + cutlass.Float32, + (2, 2, 1), + ) + torch.cuda.synchronize() + + ref = torch.einsum("lmk,lnk->lmn", a, b) + torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) + print(f"\n[DSL INFO] Results verified successfully!") + print(f"First few elements of result: \n{c[:3, :3, :3]}") + + +if __name__ == "__main__": + run_tensor_op_gemm_wrapper((512, 256, 128, 16)) diff --git a/examples/python/CuTeDSL/ampere/elementwise_add.py b/examples/python/CuTeDSL/ampere/elementwise_add.py index dc70a913..6b244b01 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_add.py +++ b/examples/python/CuTeDSL/ampere/elementwise_add.py @@ -28,16 +28,17 @@ import argparse -import torch import time from typing import Type import cuda.bindings.driver as cuda +import torch import cutlass import cutlass.cute as cute -from cutlass.cute.runtime import from_dlpack +import cutlass.cute.testing as testing import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack """ An Elementwise Addition Example using CuTe DSL. @@ -153,6 +154,7 @@ def elementwise_add_kernel( blkC = gC[blk_coord] # (TileM,TileN) blkCrd = cC[blk_coord] # (TileM, TileN) + # Note: these prints only run at compile/jit time print(f"[DSL INFO] Sliced Tensors per thread block:") print(f"[DSL INFO] blkA = {blkA.type}") print(f"[DSL INFO] blkB = {blkB.type}") @@ -189,7 +191,7 @@ def elementwise_add_kernel( print(f"[DSL INFO] thrC = {thrC.type}") print(f"[DSL INFO] thrCrd = {thrCrd.type}") - for i in cutlass.range_dynamic(0, cute.size(frgPred), 1): + for i in range(0, cute.size(frgPred), 1): val = cute.elem_less(thrCrd[i], shape) frgPred[i] = val @@ -270,9 +272,6 @@ def run_elementwise_add( warmup_iterations=2, iterations=200, ): - if not torch.cuda.is_available(): - raise RuntimeError(f"Ampere GPU is required to run this example!") - print(f"\nRunning Elementwise Add test with:") print(f"Tensor dimensions: [{M}, {N}]") print(f"Input and Output Data type: {dtype}") @@ -315,10 +314,8 @@ def run_elementwise_add( print("Executing vector add kernel...") - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - # Get the raw stream pointer as a CUstream - current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Get current CUstream from torch + current_stream = cutlass_torch.current_stream() if not skip_ref_check: compiled_func(a_tensor, b_tensor, c_tensor) @@ -329,41 +326,52 @@ def run_elementwise_add( if not benchmark: return - # Create CUDA events for timing - start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] - end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] + def generate_tensors(): + if dtype.is_integer: + a = torch.randint( + 0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype + ) + b = torch.randint( + 0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype + ) + else: + a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) - # Warmup - for _ in range(warmup_iterations): - compiled_func(a_tensor, b_tensor, c_tensor) + c = torch.zeros_like(a) - # Use the current stream for CUDA events instead of the default stream - # Record start event - cuda.cuEventRecord(start_event, current_stream) + if not is_a_dynamic_layout: + a_tensor = from_dlpack(a).mark_layout_dynamic() + else: + a_tensor = a - # Execute the kernel - for _ in range(iterations): - compiled_func(a_tensor, b_tensor, c_tensor) + if not is_b_dynamic_layout: + b_tensor = from_dlpack(b).mark_layout_dynamic() + else: + b_tensor = b - # Record end event - cuda.cuEventRecord(end_event, current_stream) - cuda.cuEventSynchronize(end_event) + if not is_result_dynamic_layout: + c_tensor = from_dlpack(c).mark_layout_dynamic() + else: + c_tensor = c - # Calculate elapsed time - err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event) - avg_time = elapsed_time / iterations + return testing.JitArguments(a_tensor, b_tensor, c_tensor) + + avg_time_us = testing.benchmark( + compiled_func, + workspace_generator=generate_tensors, + workspace_count=10, + warmup_iterations=warmup_iterations, + profiling_iterations=iterations, + ) # Print execution results - print(f"Kernel execution time: {avg_time:.4f} ms") + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") print( - f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s" + f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s" ) print(f"First few elements of result: \n{c[:3, :3]}") - # Destroy events - cuda.cuEventDestroy(start_event) - cuda.cuEventDestroy(end_event) - if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -377,6 +385,10 @@ if __name__ == "__main__": parser.add_argument("--benchmark", action="store_true") args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError(f"Ampere GPU is required to run this example!") + run_elementwise_add( args.M, args.N, diff --git a/examples/python/CuTeDSL/ampere/elementwise_apply.py b/examples/python/CuTeDSL/ampere/elementwise_apply.py index b395e9f5..649c7789 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_apply.py +++ b/examples/python/CuTeDSL/ampere/elementwise_apply.py @@ -29,14 +29,15 @@ import argparse import operator -import torch -from typing import Type import time +from typing import Type, List import cuda.bindings.driver as cuda +import torch import cutlass import cutlass.cute as cute +import cutlass.cute.testing as testing import cutlass.torch as cutlass_torch from cutlass.cute.runtime import from_dlpack @@ -77,8 +78,7 @@ while maintaining high performance through efficient memory access patterns. @cute.kernel def elementwise_apply_kernel( op: cutlass.Constexpr, - gA: cute.Tensor, - gB: cute.Tensor, + inputs: List[cute.Tensor], gC: cute.Tensor, cC: cute.Tensor, # coordinate tensor shape: cute.Shape, @@ -90,48 +90,46 @@ def elementwise_apply_kernel( # slice for CTAs cta_coord = ((None, None), bidx) # logical coord -> address - ctaA = gA[cta_coord] # (TileM, TileN) - ctaB = gB[cta_coord] # (TileM, TileN) + # Leverage the meta-programming capability of the DSL to slice the tensors for each input + # All for loops below on input tensors would be fully unrolled automatically at compile time + ctaInputs = [t[cta_coord] for t in inputs] # (TileM, TileN) ctaC = gC[cta_coord] # (TileM, TileN) ctaCrd = cC[cta_coord] # (TileM, TileN) print(f"[DSL INFO] Sliced Tensors per thread block:") - print(f"[DSL INFO] ctaA = {ctaA.type}") - print(f"[DSL INFO] ctaB = {ctaB.type}") + for i in cutlass.range_constexpr(len(ctaInputs)): + print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}") print(f"[DSL INFO] ctaC = {ctaC.type}") print(f"[DSL INFO] ctaCrd = {ctaCrd.type}") # compose with CTA TV layout # (tid, vid) -> address - tidfrgA = cute.composition(ctaA, tv_layout) - tidfrgB = cute.composition(ctaB, tv_layout) + tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs] tidfrgC = cute.composition(ctaC, tv_layout) tidfrgCrd = cute.composition(ctaCrd, tv_layout) # print(f"{tv_layout = }") - # print(f"{tidfrgA = }") + # print(f"{tidfrgAB[0] = }") thr_coord = (tidx, (None, None)) # slice for threads # vid -> address - thrA = tidfrgA[thr_coord] # (V) - thrB = tidfrgB[thr_coord] # (V) + thrInputs = [t[thr_coord] for t in tidfrgInputs] # (V) thrC = tidfrgC[thr_coord] # (V) thrCrd = tidfrgCrd[thr_coord] print(f"[DSL INFO] Sliced Tensors per thread:") - print(f"[DSL INFO] thrA = {thrA.type}") - print(f"[DSL INFO] thrB = {thrB.type}") + for i in cutlass.range_constexpr(len(thrInputs)): + print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}") print(f"[DSL INFO] thrC = {thrC.type}") print(f"[DSL INFO] thrCrd = {thrCrd.type}") # allocate fragments for gmem->rmem - frgA = cute.make_fragment_like(thrA, gA.element_type) - frgB = cute.make_fragment_like(thrB, gB.element_type) + frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs] frgC = cute.make_fragment_like(thrC, gC.element_type) frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean) - for i in cutlass.range_dynamic(cute.size(frgPred), unroll=1): + for i in cutlass.range(cute.size(frgPred), unroll=1): frgPred[i] = cute.elem_less(thrCrd[i], shape) # if tidx == 0 and bidx == 0: @@ -142,10 +140,13 @@ def elementwise_apply_kernel( ########################################################## # declare the atoms which will be used later for memory copy + # Compile time validation: expect same element type for all input tensors so as to reuse the copy atom for load + assert all(t.element_type == inputs[0].element_type for t in inputs) + copy_atom_load = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), - gA.element_type, - num_bits_per_copy=gA.element_type.width, + inputs[0].element_type, + num_bits_per_copy=inputs[0].element_type.width, ) copy_atom_store = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), @@ -153,12 +154,12 @@ def elementwise_apply_kernel( num_bits_per_copy=gC.element_type.width, ) - cute.copy(copy_atom_load, thrA, frgA, pred=frgPred) - cute.copy(copy_atom_load, thrB, frgB, pred=frgPred) + for thrInput, frgInput in zip(thrInputs, frgInputs): + cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred) # Load data before use. The compiler will optimize the copy and load # operations to convert some memory ld/st into register uses. - result = op(frgA.load(), frgB.load()) + result = op(*[frgInput.load() for frgInput in frgInputs]) # Save the results back to registers. Here we reuse b's registers. frgC.store(result) @@ -173,6 +174,7 @@ def elementwise_apply( a: cute.Tensor, b: cute.Tensor, result: cute.Tensor, + stream: cuda.CUstream, ): """CUDA kernel applying binary operator on each element of two n-D input tensors in CuTe Python and store to result tensor. @@ -262,8 +264,7 @@ def elementwise_apply( # Async token(s) can also be specified as dependencies elementwise_apply_kernel( op, - gA, - gB, + [gA, gB], # Group input tensors into a list as a single argument gC, cC, result.shape, @@ -271,6 +272,7 @@ def elementwise_apply( ).launch( grid=[cute.size(gC, mode=[1]), 1, 1], block=[cute.size(tv_layout, mode=[0]), 1, 1], + stream=stream, ) @@ -287,6 +289,11 @@ def run_elementwise_apply_and_verify( if not torch.cuda.is_available(): raise RuntimeError(f"Ampere GPU is required to run this example!") + # Create non default CUDA stream from PyTorch + torch_stream = torch.cuda.Stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + print(f"\nRunning Elementwise Apply test with:") print(f"Tensor dimensions: [{M}, {N}]") print(f"Input and Output Data type: {dtype}") @@ -309,20 +316,16 @@ def run_elementwise_apply_and_verify( if op in (operator.truediv, operator.floordiv): b = torch.where(b == 0, torch.tensor(epsilon), b) - print("Compiling kernel with cute.compile ...") - start_time = time.time() - compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) - compilation_time = time.time() - start_time - print(f"Compilation time: {compilation_time:.4f} seconds") - print("Executing elementwise apply kernel...") - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - # Get the raw stream pointer as a CUstream - current_stream = cuda.CUstream(torch_stream.cuda_stream) if not skip_ref_check: - compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) + elementwise_apply( + op, + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ) print("Verifying results...") torch.testing.assert_close(op(a, b), c) print("Results verified successfully!") @@ -330,28 +333,32 @@ def run_elementwise_apply_and_verify( if not benchmark: return - # Create CUDA events for timing - start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] - end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] + compiled_func = cute.compile( + elementwise_apply, + op, + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ) - # Warmup - for _ in range(warmup_iterations): - compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) + # When compiled we inlined op in the kernel, so we do not pass it when benchmarking - # Record start event - cuda.cuEventRecord(start_event, current_stream) + avg_time_us = testing.benchmark( + compiled_func, + kernel_arguments=testing.JitArguments( + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ), + warmup_iterations=warmup_iterations, + profiling_iterations=iterations, + use_cuda_graphs=True, + stream=current_stream, + ) - # Execute the kernel - for _ in range(iterations): - compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()) - - # Record end event - cuda.cuEventRecord(end_event, current_stream) - cuda.cuEventSynchronize(end_event) - - # Calculate elapsed time - err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event) - avg_time = elapsed_time / iterations + avg_time = avg_time_us / 1e3 # Print execution results print(f"Kernel execution time: {avg_time:.4f} ms") @@ -360,10 +367,6 @@ def run_elementwise_apply_and_verify( ) print(f"First few elements of result: \n{c[:3, :3]}") - # Destroy events - cuda.cuEventDestroy(start_event) - cuda.cuEventDestroy(end_event) - if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/examples/python/CuTeDSL/ampere/flash_attention_v2.py b/examples/python/CuTeDSL/ampere/flash_attention_v2.py index 0f41245e..b36ec1d1 100644 --- a/examples/python/CuTeDSL/ampere/flash_attention_v2.py +++ b/examples/python/CuTeDSL/ampere/flash_attention_v2.py @@ -542,13 +542,13 @@ class FlashAttentionForwardAmpere: cutlass.Boolean, ) # Set predicates for head_dim bounds, seqlen_q/k bounds is processed at the first tile. - for rest_v in range(tQpQ.shape[0]): - for rest_k in range(tQpQ.shape[2]): + for rest_v in cutlass.range_constexpr(tQpQ.shape[0]): + for rest_k in cutlass.range_constexpr(tQpQ.shape[2]): tQpQ[rest_v, 0, rest_k] = cute.elem_less( tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3] ) - for rest_v in range(tKVpKV.shape[0]): - for rest_k in range(tKVpKV.shape[2]): + for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]): + for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]): tKVpKV[rest_v, 0, rest_k] = cute.elem_less( tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3] ) @@ -556,7 +556,7 @@ class FlashAttentionForwardAmpere: # Prefetch Prologue # /////////////////////////////////////////////////////////////////////////////// # Start async loads of the last mn-tile, where we take care of the mn residue - for m in range(cute.size(tQsQ.shape[1])): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]): cute.copy( gmem_tiled_copy_QKV, @@ -567,7 +567,7 @@ class FlashAttentionForwardAmpere: else: # Clear the smem tiles to account for predicated off loads tQsQ[None, m, None].fill(0) - for n in range(cute.size(tKsK.shape[1])): + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]): cute.copy( gmem_tiled_copy_QKV, @@ -644,13 +644,13 @@ class FlashAttentionForwardAmpere: # We also need masking on S if it's causal, for the last ceil_div(m_block_size, n_block_size) blocks. # We will have at least 1 "masking" iteration. mask_steps = 1 - if self._is_causal: + if cutlass.const_expr(self._is_causal): mask_steps = cute.ceil_div(self._m_block_size, self._n_block_size) - for n_tile in range(mask_steps): + for n_tile in cutlass.range_constexpr(mask_steps): n_block = n_block_max - n_tile - 1 basic_params.n_block = n_block - if self._is_causal: + if cutlass.const_expr(self._is_causal): if n_block >= 0: self.compute_one_n_block( basic_params, @@ -673,7 +673,7 @@ class FlashAttentionForwardAmpere: ) # Start async loads of rest k-tiles in reverse order, no k-residue handling needed - for n_tile in cutlass.range_dynamic(mask_steps, n_block_max, 1): + for n_tile in range(mask_steps, n_block_max, 1): n_block = n_block_max - n_tile - 1 basic_params.n_block = n_block self.compute_one_n_block( @@ -748,13 +748,13 @@ class FlashAttentionForwardAmpere: ), cutlass.Boolean, ) - for rest_v in range(tOpO.shape[0]): - for rest_n in range(cute.size(tOpO.shape[2])): + for rest_v in cutlass.range_constexpr(tOpO.shape[0]): + for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])): tOpO[rest_v, 0, rest_n] = cute.elem_less( tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3] ) # copy acc O from rmem to gmem - for rest_m in range(cute.size(tOpO.shape[1])): + for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])): if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]): cute.copy( gmem_tiled_copy_O, @@ -804,7 +804,7 @@ class FlashAttentionForwardAmpere: # load smem tile V for O, special process for the first tile to avoid loading nan. # The `if` here is a constexpr, won't be generated in the IR. if is_first_n_block: - for n in range(cute.size(gmem_copy_params.tVsV.shape[1])): + for n in cutlass.range_constexpr(cute.size(gmem_copy_params.tVsV.shape[1])): if cute.elem_less( gmem_copy_params.tKVcKV[0, n, 0][1], basic_params.mK.layout.shape[1], @@ -841,7 +841,7 @@ class FlashAttentionForwardAmpere: smem_copy_params.tSrK_copy_view[None, None, 0], ) # mma for S - for k in range(cute.size(smem_copy_params.tSsQ.shape[2])): + for k in cutlass.range_constexpr(cute.size(smem_copy_params.tSsQ.shape[2])): # load next QK k-block from smem to rmem for mma k_next = (k + 1) % cute.size(smem_copy_params.tSsQ.shape[2]) cute.copy( @@ -916,7 +916,7 @@ class FlashAttentionForwardAmpere: smem_copy_params.tOrVt_copy_view[None, None, 0], ) # mma for O - for k in range(cute.size(tOrS.shape[2])): + for k in cutlass.range_constexpr(cute.size(tOrS.shape[2])): # load next V k-block from smem to rmem for mma k_next = (k + 1) % cute.size(tOrS.shape[2]) cute.copy( @@ -965,14 +965,14 @@ class FlashAttentionForwardAmpere: acc_O_mn = self._make_acc_tensor_mn_view(mma_params.acc_O) row_max_prev = None # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row. - if not is_first_n_block: + if cutlass.const_expr(not is_first_n_block): row_max_prev = cute.make_fragment_like( softmax_params.row_max, cutlass.Float32 ) cute.basic_copy(softmax_params.row_max, row_max_prev) # if it is the first tile, create a mask for residual of S to -inf for softmax. tScS_mn = None - if in_mask_steps: + if cutlass.const_expr(in_mask_steps): mcS = cute.make_identity_tensor( ( basic_params.mQ.shape[0], @@ -990,12 +990,12 @@ class FlashAttentionForwardAmpere: tScS_mn = self._make_acc_tensor_mn_view(tScS) # Each iteration processes one row of acc_S - for r in range(cute.size(softmax_params.row_max)): + for r in cutlass.range_constexpr(cute.size(softmax_params.row_max)): # mask residual of S with -inf - if in_mask_steps: - if not self._is_causal: + if cutlass.const_expr(in_mask_steps): + if cutlass.const_expr(not self._is_causal): # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): if cute.elem_less( basic_params.mK.shape[1], tScS_mn[0, c][3] + 1 ): @@ -1006,7 +1006,7 @@ class FlashAttentionForwardAmpere: tScS_mn[r, 0][1] + 1, basic_params.mK.shape[1] ) # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): # only consider the column index, so the row index sets to 0. if cute.elem_less(col_idx_limit, tScS_mn[0, c][3] + 1): acc_S_mn[r, c] = -cutlass.Float32.inf @@ -1021,10 +1021,10 @@ class FlashAttentionForwardAmpere: row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row) row_max_prev_row = None # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row. - if not is_first_n_block: + if cutlass.const_expr(not is_first_n_block): row_max_prev_row = row_max_prev[r] row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row) - if self._is_causal: + if cutlass.const_expr(self._is_causal): row_max_cur_row = ( 0.0 if row_max_cur_row == -cutlass.Float32.inf else row_max_cur_row ) @@ -1043,7 +1043,7 @@ class FlashAttentionForwardAmpere: cute.ReductionOp.ADD, cutlass.Float32.zero, 0 ) # if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum. - if not is_first_n_block: + if cutlass.const_expr(not is_first_n_block): prev_minus_cur_exp = self._exp2f( row_max_prev_row * softmax_params.softmax_scale_log2 - row_max_cur_row * softmax_params.softmax_scale_log2 @@ -1072,7 +1072,7 @@ class FlashAttentionForwardAmpere: """ # do quad reduction for row_sum. acc_O_mn = self._make_acc_tensor_mn_view(acc_O) - for r in range(cute.size(row_sum)): + for r in cutlass.range_constexpr(cute.size(row_sum)): row_sum[r] = self._threadquad_reduce_sum(row_sum[r]) # if row_sum is zero or nan, set acc_O_mn_row to 1.0 acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] diff --git a/examples/python/CuTeDSL/ampere/sgemm.py b/examples/python/CuTeDSL/ampere/sgemm.py index a4a032b4..474366e5 100644 --- a/examples/python/CuTeDSL/ampere/sgemm.py +++ b/examples/python/CuTeDSL/ampere/sgemm.py @@ -35,6 +35,8 @@ import torch import cutlass import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch import cutlass.utils as utils from cutlass.cute.runtime import from_dlpack @@ -109,6 +111,7 @@ class SGemm: mB: cute.Tensor, mC: cute.Tensor, epilogue_op: cutlass.Constexpr = lambda x: x, + stream: cuda.CUstream = cuda.CUstream(cuda.CUstream_flags.CU_STREAM_DEFAULT), ): self.a_major_mode = utils.LayoutEnum.from_tensor(mA) self.b_major_mode = utils.LayoutEnum.from_tensor(mB) @@ -168,7 +171,7 @@ class SGemm: num_bits_per_copy=mB.element_type.width, ) - if self.a_major_mode == utils.LayoutEnum.COL_MAJOR: + if cutlass.const_expr(self.a_major_mode == utils.LayoutEnum.COL_MAJOR): num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1 atom_async_copy_A = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), @@ -182,7 +185,7 @@ class SGemm: ) vA = cute.make_layout((num_vectorized, 1)) - if self.b_major_mode == utils.LayoutEnum.COL_MAJOR: + if cutlass.const_expr(self.b_major_mode == utils.LayoutEnum.COL_MAJOR): num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1 atom_async_copy_B = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), @@ -222,7 +225,7 @@ class SGemm: atoms_layout = cute.make_layout( (self._num_threads // 16, 16, 1), stride=(16, 1, 0) ) - if self.c_major_mode == utils.LayoutEnum.COL_MAJOR: + if cutlass.const_expr(self.c_major_mode == utils.LayoutEnum.COL_MAJOR): atoms_layout = cute.make_layout( (16, self._num_threads // 16, 1), stride=(1, 16, 0) ) @@ -256,6 +259,7 @@ class SGemm: grid=grid_dim, block=[cute.size(atoms_layout), 1, 1], smem=smem_size, + stream=stream, ) @cute.kernel @@ -540,8 +544,8 @@ class SGemm: # 3. Combining the smem and register pipelines results in the mainloop. # /////////////////////////////////////////////////////////////////////////////// - for _ in cutlass.range_dynamic(k_tile_count, unroll=1): - for k_block in range(k_block_max): + for _ in range(k_tile_count): + for k_block in range(k_block_max, unroll_full=True): if k_block == k_block_max - 1: tCsA_p = tCsA[None, None, None, smem_pipe_read] tCsB_p = tCsB[None, None, None, smem_pipe_read] @@ -639,7 +643,6 @@ def main( iterations: int = 100, skip_ref_check: bool = False, ): - torch.manual_seed(1024) M, N, K = problem_shape # Create and permute tensor A/B/C @@ -694,51 +697,36 @@ def main( sgemm = SGemm() + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + print("Compiling kernel with cute.compile ...") start_time = time.time() - gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor) + gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream) compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") print("Executing GEMM kernel...") - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - - # Get the raw stream pointer as a CUstream - current_stream = cuda.CUstream(torch_stream.cuda_stream) - - # Create CUDA events for timing - start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] - end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] - - # Warmup - for _ in range(warmup_iterations): - gemm(a_tensor, b_tensor, c_tensor) - - # Use the current stream for CUDA events instead of the default stream - # Record start event - cuda.cuEventRecord(start_event, current_stream) - - # Execute the kernel - for _ in range(iterations): - gemm(a_tensor, b_tensor, c_tensor) - - # Record end event - cuda.cuEventRecord(end_event, current_stream) - cuda.cuEventSynchronize(end_event) - - # Calculate elapsed time - err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event) + avg_time_us = testing.benchmark( + gemm, + kernel_arguments=testing.JitArguments( + a_tensor, b_tensor, c_tensor, current_stream + ), + warmup_iterations=warmup_iterations, + profiling_iterations=iterations, + use_cuda_graphs=False, + stream=current_stream, + ) # Print execution results - print(f"Kernel execution time: {elapsed_time / iterations:.4f} ms") - - # Destroy events - cuda.cuEventDestroy(start_event) - cuda.cuEventDestroy(end_event) + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") if not skip_ref_check: + gemm(a_tensor, b_tensor, c_tensor) + torch.cuda.synchronize() print("Verifying results...") ref = torch.einsum("mk,nk->mn", a, b) torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) @@ -768,6 +756,9 @@ if __name__ == "__main__": args = parser.parse_args() print("Running SIMT GEMM example:") + + torch.manual_seed(1024) + main( args.a_major, args.b_major, diff --git a/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/examples/python/CuTeDSL/ampere/tensorop_gemm.py index cc93f93d..2e6482f9 100644 --- a/examples/python/CuTeDSL/ampere/tensorop_gemm.py +++ b/examples/python/CuTeDSL/ampere/tensorop_gemm.py @@ -36,6 +36,7 @@ import torch import cutlass import cutlass.cute as cute +import cutlass.cute.testing as testing import cutlass.torch as cutlass_torch import cutlass.utils as utils from cutlass.cute.runtime import from_dlpack @@ -48,6 +49,7 @@ A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE D This GEMM kernel supports the following features: - Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations + - Threadblock rasterization to improve data re-use - Supports multi-stage pipeline to overlap computation and memory access - Implements shared memory buffering for epilogue to increase coalesed global memory access @@ -253,6 +255,22 @@ class TensorOpGemm: # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l) grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + # Add threadblock rasterization to improve re-use of data + raster_factor = 1 + grid_dim_n = cute.size(grid_dim[1]) + # Thresholds picked so that it doesn't cause too many no-op CTAs + if grid_dim_n > 5: + raster_factor = 8 + elif grid_dim_n > 2: + raster_factor = 4 + elif grid_dim_n > 1: + raster_factor = 2 + rasterization_remap_grid_dim = ( + cute.size(grid_dim[0]) * raster_factor, + (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor, + cute.size(grid_dim[2]), + ) + self.kernel( mA, mB, @@ -264,9 +282,10 @@ class TensorOpGemm: tiled_copy_B, tiled_copy_C, tiled_mma, + raster_factor, epilogue_op, ).launch( - grid=grid_dim, + grid=rasterization_remap_grid_dim, block=[self.num_threads, 1, 1], smem=smem_size, ) @@ -284,436 +303,445 @@ class TensorOpGemm: tiled_copy_B: cute.TiledCopy, tiled_copy_C: cute.TiledCopy, tiled_mma: cute.TiledMma, + rasterization_factor: cutlass.Int32, epilogue_op: cutlass.Constexpr = lambda x: x, ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() bidx, bidy, bidz = cute.arch.block_idx() - tiler_coord = (bidx, bidy, None) - - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) - # /////////////////////////////////////////////////////////////////////////////// - gA = cute.local_tile( - mA[None, None, bidz], - tiler=self.cta_tiler, - coord=tiler_coord, - proj=(1, None, 1), - ) - gB = cute.local_tile( - mB[None, None, bidz], - tiler=self.cta_tiler, - coord=tiler_coord, - proj=(None, 1, 1), - ) - gC = cute.local_tile( - mC[None, None, bidz], - tiler=self.cta_tiler, - coord=tiler_coord, - proj=(1, 1, None), + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + offset_tile_x, offset_tile_y = self.raster_tile( + bidx, bidy, rasterization_factor ) + # Early exit if CTA is out of range + if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y: + pass + else: + tiler_coord = (offset_tile_x, offset_tile_y, None) - # By default, if the tensor k mode does not divide into the tile k - # size, then last tiles in the k dimension are irregular. - # Instead, make the first tiles irregular when k is irregular. - # This allows us to handle the irregular tile first to avoid - # checking for this condition within the mainloop. - - # residual_k is a negative number indicating the amount needed to - # shift the pointer by in dimension k - residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size( - gA, mode=[2] - ) - - # move the pointer of gA/gB in the `-k` direction - gA = cute.domain_offset((0, residual_k, 0), gA) - gB = cute.domain_offset((0, residual_k, 0), gB) - # input is 16B aligned - gA = cute.make_tensor(gA.iterator.align(16), gA.layout) - gB = cute.make_tensor(gB.iterator.align(16), gB.layout) - - # Construct identity layout for sA and sB (mirrors global tensors, - # used for predication only) - mcA = cute.make_identity_tensor(mA.layout.shape) - mcB = cute.make_identity_tensor(mB.layout.shape) - cA = cute.local_tile( - mcA[None, None, bidz], - tiler=self.cta_tiler, - coord=tiler_coord, - proj=(1, None, 1), - ) - cB = cute.local_tile( - mcB[None, None, bidz], - tiler=self.cta_tiler, - coord=tiler_coord, - proj=(None, 1, 1), - ) - - cA = cute.domain_offset((0, residual_k, 0), cA) - cB = cute.domain_offset((0, residual_k, 0), cB) - - # /////////////////////////////////////////////////////////////////////////////// - # Create shared memory buffers and get the appropriate fragments for this thread. - # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) - # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) - # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) - # /////////////////////////////////////////////////////////////////////////////// - # Shared memory buffer - smem = cutlass.utils.SmemAllocator() - - sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) - sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) - sC = cute.make_tensor( - cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout - ) - - thr_copy_A = tiled_copy_A.get_slice(tidx) - thr_copy_B = tiled_copy_B.get_slice(tidx) - thr_copy_C = tiled_copy_C.get_slice(tidx) - tAgA = thr_copy_A.partition_S(gA) - tAsA = thr_copy_A.partition_D(sA) - tBgB = thr_copy_B.partition_S(gB) - tBsB = thr_copy_B.partition_D(sB) - tCsC_epilogue = thr_copy_C.partition_S(sC) - tCgC_epilogue = thr_copy_C.partition_D(gC) - - # Repeat the partitioning with identity layouts - tAcA = thr_copy_A.partition_S(cA) - tBcB = thr_copy_B.partition_S(cB) - - # /////////////////////////////////////////////////////////////////////////////// - # Predicate: Mark indices that need to copy when problem_shape isn't a multiple - # of tile_shape - # /////////////////////////////////////////////////////////////////////////////// - - # For predication over the tensors A (M/K), B (N/K), and (in the - # epilogue) C (M/N), we will compute it in a fashion similar to an - # outer product. The predication along one of the dimensions is - # evaluated and stored in a predication tensor. Then, the - # predication for the remaining dimension is handled later via an - # if/else branch at the copy. - # For A and B, predication booleans along M/N are stored in a - # predication tensor and along K is handled via a if/else branch. - - # Allocate predicate tensors for M and N. Predication is checked - # at the granularity of a copy atom, so the predicate tensor does not - # need separate booleans for individual elements within a copy - # atom (for example, the elements of tAgA.shape[0][0].) - tApA = cute.make_fragment( - cute.make_layout( - ( - tAgA.shape[0][1], - cute.size(tAgA, mode=[1]), - cute.size(tAgA, mode=[2]), - ), - stride=(cute.size(tAgA, mode=[1]), 1, 0), - ), - cutlass.Boolean, - ) - tBpB = cute.make_fragment( - cute.make_layout( - ( - tBsB.shape[0][1], - cute.size(tBsB, mode=[1]), - cute.size(tBsB, mode=[2]), - ), - stride=(cute.size(tBsB, mode=[1]), 1, 0), - ), - cutlass.Boolean, - ) - # Set predicates for M/N bounds - for rest_v in range(tApA.shape[0]): - for m in range(tApA.shape[1]): - tApA[rest_v, m, 0] = cute.elem_less( - tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] - ) - for rest_v in range(tBpB.shape[0]): - for n in range(tBpB.shape[1]): - tBpB[rest_v, n, 0] = cute.elem_less( - tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] - ) - - # /////////////////////////////////////////////////////////////////////////////// - # Prefetch Prologue - # /////////////////////////////////////////////////////////////////////////////// - # Clear the smem tiles to account for predicated off loads - tAsA.fill(0) - tBsB.fill(0) - cute.arch.sync_threads() - # Start async loads for the first k-tile. Here we take care of the k residue - # via if/else check along the k dimension. Because we shifted the identity tensor - # by the residue_k and because the identity tensor is a counting tensor, the - # values of any identity tensor element that is poison is less than -1 - num_smem_stages = cute.size(tAsA, mode=[3]) - k_tile_count = cute.size(tAgA, mode=[3]) - k_tile_index = cutlass.Int32(0) - - for k in range(tApA.shape[2]): - if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]): - cute.copy( - tiled_copy_A, - tAgA[None, None, k, k_tile_index], - tAsA[None, None, k, 0], - pred=tApA[None, None, k], - ) - for k in range(tBpB.shape[2]): - if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]): - cute.copy( - tiled_copy_B, - tBgB[None, None, k, k_tile_index], - tBsB[None, None, k, 0], - pred=tBpB[None, None, k], - ) - k_tile_index = k_tile_index + 1 - cute.arch.cp_async_commit_group() - - # Start async loads for rest of the k-tiles - for k_tile in range(1, num_smem_stages - 1): - if k_tile == k_tile_count: - tApA.fill(0) - tBpB.fill(0) - cute.copy( - tiled_copy_A, - tAgA[None, None, None, k_tile_index], - tAsA[None, None, None, k_tile], - pred=tApA, + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) + # /////////////////////////////////////////////////////////////////////////////// + gA = cute.local_tile( + mA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), ) - cute.copy( - tiled_copy_B, - tBgB[None, None, None, k_tile_index], - tBsB[None, None, None, k_tile], - pred=tBpB, + gB = cute.local_tile( + mB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), ) + gC = cute.local_tile( + mC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + + # By default, if the tensor k mode does not divide into the tile k + # size, then last tiles in the k dimension are irregular. + # Instead, make the first tiles irregular when k is irregular. + # This allows us to handle the irregular tile first to avoid + # checking for this condition within the mainloop. + + # residual_k is a negative number indicating the amount needed to + # shift the pointer by in dimension k + residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size( + gA, mode=[2] + ) + + # move the pointer of gA/gB in the `-k` direction + gA = cute.domain_offset((0, residual_k, 0), gA) + gB = cute.domain_offset((0, residual_k, 0), gB) + # input is 16B aligned + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + + # Construct identity layout for sA and sB (mirrors global tensors, + # used for predication only) + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile( + mcA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + cB = cute.local_tile( + mcB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + + cA = cute.domain_offset((0, residual_k, 0), cA) + cB = cute.domain_offset((0, residual_k, 0), cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Create shared memory buffers and get the appropriate fragments for this thread. + # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) + # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) + # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory buffer + smem = cutlass.utils.SmemAllocator() + + sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) + sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) + sC = cute.make_tensor( + cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout + ) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + thr_copy_C = tiled_copy_C.get_slice(tidx) + tAgA = thr_copy_A.partition_S(gA) + tAsA = thr_copy_A.partition_D(sA) + tBgB = thr_copy_B.partition_S(gB) + tBsB = thr_copy_B.partition_D(sB) + tCsC_epilogue = thr_copy_C.partition_S(sC) + tCgC_epilogue = thr_copy_C.partition_D(gC) + + # Repeat the partitioning with identity layouts + tAcA = thr_copy_A.partition_S(cA) + tBcB = thr_copy_B.partition_S(cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + + # For predication over the tensors A (M/K), B (N/K), and (in the + # epilogue) C (M/N), we will compute it in a fashion similar to an + # outer product. The predication along one of the dimensions is + # evaluated and stored in a predication tensor. Then, the + # predication for the remaining dimension is handled later via an + # if/else branch at the copy. + # For A and B, predication booleans along M/N are stored in a + # predication tensor and along K is handled via a if/else branch. + + # Allocate predicate tensors for M and N. Predication is checked + # at the granularity of a copy atom, so the predicate tensor does not + # need separate booleans for individual elements within a copy + # atom (for example, the elements of tAgA.shape[0][0].) + tApA = cute.make_fragment( + cute.make_layout( + ( + tAgA.shape[0][1], + cute.size(tAgA, mode=[1]), + cute.size(tAgA, mode=[2]), + ), + stride=(cute.size(tAgA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + tBpB = cute.make_fragment( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=(cute.size(tBsB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + # Set predicates for M/N bounds + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] + ) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cute.elem_less( + tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Clear the smem tiles to account for predicated off loads + tAsA.fill(0) + tBsB.fill(0) + cute.arch.sync_threads() + # Start async loads for the first k-tile. Here we take care of the k residue + # via if/else check along the k dimension. Because we shifted the identity tensor + # by the residue_k and because the identity tensor is a counting tensor, the + # values of any identity tensor element that is poison is less than -1 + num_smem_stages = cute.size(tAsA, mode=[3]) + k_tile_count = cute.size(tAgA, mode=[3]) + k_tile_index = cutlass.Int32(0) + + for k in range(tApA.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]): + cute.copy( + tiled_copy_A, + tAgA[None, None, k, k_tile_index], + tAsA[None, None, k, 0], + pred=tApA[None, None, k], + ) + for k in range(tBpB.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]): + cute.copy( + tiled_copy_B, + tBgB[None, None, k, k_tile_index], + tBsB[None, None, k, 0], + pred=tBpB[None, None, k], + ) k_tile_index = k_tile_index + 1 cute.arch.cp_async_commit_group() - # /////////////////////////////////////////////////////////////////////////////// - # Tile MMA compute thread partitions and allocate accumulators - # /////////////////////////////////////////////////////////////////////////////// - thr_mma = tiled_mma.get_slice(tidx) - tCsA = thr_mma.partition_A(sA) - tCsB = thr_mma.partition_B(sB) - tCsC = thr_mma.partition_C(sC) - tCgC = thr_mma.partition_C(gC) - tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) - tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) - tCrC = tiled_mma.make_fragment_C(tCgC) - # Clear the accumulator - tCrC.fill(0.0) + # Start async loads for rest of the k-tiles + for k_tile in range(1, num_smem_stages - 1): + if k_tile == k_tile_count: + tApA.fill(0) + tBpB.fill(0) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, k_tile], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, k_tile], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() - # /////////////////////////////////////////////////////////////////////////////// - # Copy Atom A/B retiling - # /////////////////////////////////////////////////////////////////////////////// + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCsC = thr_mma.partition_C(sC) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + # Clear the accumulator + tCrC.fill(0.0) - # Create the copy atoms for the copy from shared memory to register - atom_copy_s2r_A = cute.make_copy_atom( - cute.nvgpu.warp.LdMatrix8x8x16bOp( - self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 - ), - mA.element_type, - ) - atom_copy_s2r_B = cute.make_copy_atom( - cute.nvgpu.warp.LdMatrix8x8x16bOp( - self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 - ), - mB.element_type, - ) + # /////////////////////////////////////////////////////////////////////////////// + # Copy Atom A/B retiling + # /////////////////////////////////////////////////////////////////////////////// - # Creates the tiled copy so that it matches the thread-value layout - # expected by the tiled mma - tiled_copy_s2r_A = cute.make_tiled_copy( - atom_copy_s2r_A, - layout_tv=tiled_mma.tv_layout_A_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), - ) - tiled_copy_s2r_B = cute.make_tiled_copy( - atom_copy_s2r_B, - layout_tv=tiled_mma.tv_layout_B_tiled, - tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - ) - - thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) - thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) - tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) - tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) - tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) - tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) - - # Current pipe index in smem to read from / write to - smem_pipe_read = 0 - smem_pipe_write = num_smem_stages - 1 - - tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] - tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] - - # /////////////////////////////////////////////////////////////////////////////// - # PREFETCH register pipeline - # /////////////////////////////////////////////////////////////////////////////// - num_k_block = cute.size(tCrA, mode=[2]) - if num_k_block > 1: - # Wait until our first prefetched tile is loaded in - cute.arch.cp_async_wait_group(num_smem_stages - 2) - cute.arch.sync_threads() - # Prefetch the first k-block rmem from the first k-tile - cute.copy( - tiled_copy_s2r_A, - tCsA_p[None, None, 0], - tCrA_copy_view[None, None, 0], + # Create the copy atoms for the copy from shared memory to register + atom_copy_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mA.element_type, ) - cute.copy( - tiled_copy_s2r_B, - tCsB_p[None, None, 0], - tCrB_copy_view[None, None, 0], + atom_copy_s2r_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mB.element_type, ) - # /////////////////////////////////////////////////////////////////////////////// - # Mainloop - # 1. Shared memory pipeline (gmem -> smem): - # The default smem pipeline depth is 3, meaning that for shared - # memory buffers, we allocate three times the size described by the - # CTA tiler. We prefetch 2 of these buffers before entering the main - # loop. Considering only the transfer from global memory to shared - # memory, the general structure of the mainloop is: - # (1) copy k-tile from gmem to smem; - # (2) perform gemm computation on k-tile; - # (3) wait for the next copy to finish. - # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command - # waits for the number of unfinished 'copy' to be <= 1. The advantage - # of this approach is that it allows for simultaneous production - # (i.e., step (1)) and consumption (i.e., step (2)) of smem. - # A common misconception is to prefetch N buffers and rewrite - # the pipeline logic to wait on N-1 pending copies. The disadvantage - # of this approach is that it requires fully consuming a buffer in - # order to open an empty buffer for the next copy. - # 2. Register pipeline (smem -> register): - # Similarly, the register pipeline produces i+1, consumes i, and - # produces i+2... Notably, i and i+1 do not use the same register, - # eliminating dependencies on the same register for better parallelism. - # 3. Combining the smem and register pipelines results in the mainloop. - # /////////////////////////////////////////////////////////////////////////////// - for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1): - for k_block in range(num_k_block): - if k_block == num_k_block - 1: - tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] - tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] - cute.arch.cp_async_wait_group(num_smem_stages - 2) - cute.arch.sync_threads() + # Creates the tiled copy so that it matches the thread-value layout + # expected by the tiled mma + tiled_copy_s2r_A = cute.make_tiled_copy( + atom_copy_s2r_A, + layout_tv=tiled_mma.tv_layout_A_tiled, + tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), + ) + tiled_copy_s2r_B = cute.make_tiled_copy( + atom_copy_s2r_B, + layout_tv=tiled_mma.tv_layout_B_tiled, + tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), + ) - # Load A, B from shared memory to registers for k_block + 1 - k_block_next = (k_block + 1) % num_k_block # static + thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) + thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + + # Current pipe index in smem to read from / write to + smem_pipe_read = 0 + smem_pipe_write = num_smem_stages - 1 + + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + + # /////////////////////////////////////////////////////////////////////////////// + # PREFETCH register pipeline + # /////////////////////////////////////////////////////////////////////////////// + num_k_block = cute.size(tCrA, mode=[2]) + if num_k_block > 1: + # Wait until our first prefetched tile is loaded in + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + # Prefetch the first k-block rmem from the first k-tile cute.copy( tiled_copy_s2r_A, - tCsA_p[None, None, k_block_next], - tCrA_copy_view[None, None, k_block_next], + tCsA_p[None, None, 0], + tCrA_copy_view[None, None, 0], ) cute.copy( tiled_copy_s2r_B, - tCsB_p[None, None, k_block_next], - tCrB_copy_view[None, None, k_block_next], + tCsB_p[None, None, 0], + tCrB_copy_view[None, None, 0], ) - # Fetch next A: To better interleave global memory access and compute - # instructions, we intentionally use the sequence: copy A, perform GEMM, - # then copy B. - if k_block == 0: - if k_tile + num_smem_stages - 1 < k_tile_count: - cute.copy( - tiled_copy_A, - tAgA[None, None, None, k_tile_index], - tAsA[None, None, None, smem_pipe_write], - pred=tApA, - ) + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # 1. Shared memory pipeline (gmem -> smem): + # The default smem pipeline depth is 3, meaning that for shared + # memory buffers, we allocate three times the size described by the + # CTA tiler. We prefetch 2 of these buffers before entering the main + # loop. Considering only the transfer from global memory to shared + # memory, the general structure of the mainloop is: + # (1) copy k-tile from gmem to smem; + # (2) perform gemm computation on k-tile; + # (3) wait for the next copy to finish. + # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command + # waits for the number of unfinished 'copy' to be <= 1. The advantage + # of this approach is that it allows for simultaneous production + # (i.e., step (1)) and consumption (i.e., step (2)) of smem. + # A common misconception is to prefetch N buffers and rewrite + # the pipeline logic to wait on N-1 pending copies. The disadvantage + # of this approach is that it requires fully consuming a buffer in + # order to open an empty buffer for the next copy. + # 2. Register pipeline (smem -> register): + # Similarly, the register pipeline produces i+1, consumes i, and + # produces i+2... Notably, i and i+1 do not use the same register, + # eliminating dependencies on the same register for better parallelism. + # 3. Combining the smem and register pipelines results in the mainloop. + # /////////////////////////////////////////////////////////////////////////////// + for k_tile in range(k_tile_count): + for k_block in cutlass.range(num_k_block, unroll_full=True): + if k_block == num_k_block - 1: + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() - # Thread-level register gemm for k_block - cute.gemm( - tiled_mma, - tCrC, - tCrA[None, None, k_block], - tCrB[None, None, k_block], - tCrC, - ) - - # Fetch next B and update smem pipeline read/write - if k_block == 0: - if k_tile + num_smem_stages - 1 < k_tile_count: - cute.copy( - tiled_copy_B, - tBgB[None, None, None, k_tile_index], - tBsB[None, None, None, smem_pipe_write], - pred=tBpB, - ) - k_tile_index = k_tile_index + 1 - cute.arch.cp_async_commit_group() - smem_pipe_write = smem_pipe_read - smem_pipe_read = smem_pipe_read + 1 - if smem_pipe_read == num_smem_stages: - smem_pipe_read = 0 - - # Sync before epilogue - cute.arch.cp_async_wait_group(0) - cute.arch.sync_threads() - - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue with fusion - # /////////////////////////////////////////////////////////////////////////////// - tCrD = cute.make_fragment_like(tCrC, self.c_dtype) - tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype) - - # Copy results of D back to shared memory - cute.autovec_copy(tCrD, tCsC) - - # Create counting tensor for C - ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) - mcC = cute.make_identity_tensor( - ( - cute.size(ceilM) * self.cta_tiler[0], - cute.size(ceilN) * self.cta_tiler[1], - 1, - ) - ) - cC = cute.local_tile( - mcC[None, None, bidz], - tiler=self.cta_tiler, - coord=tiler_coord, - proj=(1, 1, None), - ) - tCcC = thr_copy_C.partition_S(cC) - - tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue) - # Wait for all writes to shared memory to finish before starting copies - # using the new layouts - cute.arch.sync_threads() - cute.autovec_copy(tCsC_epilogue, tCrC_epilogue) - - # Create predication tensor for m - tCpC = cute.make_fragment( - cute.make_layout( - ( - tCgC_epilogue.shape[0][1], - cute.size(tCgC_epilogue, mode=[1]), - cute.size(tCgC_epilogue, mode=[2]), - ), - stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0), - ), - cutlass.Boolean, - ) - for rest_v in range(tCpC.shape[0]): - for m in range(tCpC.shape[1]): - tCpC[rest_v, m, 0] = cute.elem_less( - tCcC[(0, rest_v), m, 0][0], mC.shape[0] - ) - - # Copy to global memory using better vectorization - for rest_v in range(tCpC.shape[0]): - for n in range(tCpC.shape[2]): - if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]): + # Load A, B from shared memory to registers for k_block + 1 + k_block_next = (k_block + 1) % num_k_block # static cute.copy( - tiled_copy_C, - tCrC_epilogue[None, None, n], - tCgC_epilogue[None, None, n], - pred=tCpC[None, None, n], + tiled_copy_s2r_A, + tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next], ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next], + ) + + # Fetch next A: To better interleave global memory access and compute + # instructions, we intentionally use the sequence: copy A, perform GEMM, + # then copy B. + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, smem_pipe_write], + pred=tApA, + ) + + # Thread-level register gemm for k_block + cute.gemm( + tiled_mma, + tCrC, + tCrA[None, None, k_block], + tCrB[None, None, k_block], + tCrC, + ) + + # Fetch next B and update smem pipeline read/write + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, smem_pipe_write], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + smem_pipe_write = smem_pipe_read + smem_pipe_read = smem_pipe_read + 1 + if smem_pipe_read == num_smem_stages: + smem_pipe_read = 0 + + # Sync before epilogue + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue with fusion + # /////////////////////////////////////////////////////////////////////////////// + tCrD = cute.make_fragment_like(tCrC, self.c_dtype) + tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype) + + # Copy results of D back to shared memory + cute.autovec_copy(tCrD, tCsC) + + # Create counting tensor for C + ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + mcC = cute.make_identity_tensor( + ( + cute.size(ceilM) * self.cta_tiler[0], + cute.size(ceilN) * self.cta_tiler[1], + 1, + ) + ) + cC = cute.local_tile( + mcC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + tCcC = thr_copy_C.partition_S(cC) + + tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue) + # Wait for all writes to shared memory to finish before starting copies + # using the new layouts + cute.arch.sync_threads() + cute.autovec_copy(tCsC_epilogue, tCrC_epilogue) + + # Create predication tensor for m + tCpC = cute.make_fragment( + cute.make_layout( + ( + tCgC_epilogue.shape[0][1], + cute.size(tCgC_epilogue, mode=[1]), + cute.size(tCgC_epilogue, mode=[2]), + ), + stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + for rest_v in range(tCpC.shape[0]): + for m in range(tCpC.shape[1]): + tCpC[rest_v, m, 0] = cute.elem_less( + tCcC[(0, rest_v), m, 0][0], mC.shape[0] + ) + + # Copy to global memory using better vectorization + for rest_v in range(tCpC.shape[0]): + for n in range(tCpC.shape[2]): + if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]): + cute.copy( + tiled_copy_C, + tCrC_epilogue[None, None, n], + tCgC_epilogue[None, None, n], + pred=tCpC[None, None, n], + ) return def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler): @@ -811,6 +839,11 @@ class TensorOpGemm: tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout) return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn) + def raster_tile(self, i, j, f): + new_i = i // f + new_j = (i % f) + (j * f) + return (new_i, new_j) + def run_tensor_op_gemm( a_major: str, @@ -892,15 +925,18 @@ def run_tensor_op_gemm( print("Executing GEMM kernel...") - # Warmup - for _ in range(warmup_iterations): - gemm(a_tensor, b_tensor, c_tensor) + avg_time_us = testing.benchmark( + gemm, + kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor), + warmup_iterations=warmup_iterations, + profiling_iterations=iterations, + use_cuda_graphs=False, + ) - # Execute the kernel - for _ in range(iterations): - gemm(a_tensor, b_tensor, c_tensor) + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") if not skip_ref_check: + gemm(a_tensor, b_tensor, c_tensor) print("Verifying results...") torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) print("Results verified successfully!") diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm.py b/examples/python/CuTeDSL/blackwell/dense_gemm.py index 89696c8a..77c5c923 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -35,6 +35,7 @@ import torch import cutlass import cutlass.cute as cute import cutlass.utils as utils +import cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils @@ -211,7 +212,7 @@ class DenseGemmKernel: self.occupancy = 1 self.threads_per_cta = 128 - self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -283,7 +284,7 @@ class DenseGemmKernel: self.epi_tile, self.c_dtype, self.c_layout, - self.num_smem_capacity, + self.smem_capacity, self.occupancy, self.use_tma_store, ) @@ -308,7 +309,7 @@ class DenseGemmKernel: self.epi_tile, self.num_c_stage, ) - if cutlass.const_expr(self.use_tma_store) + if self.use_tma_store else None ) @@ -372,9 +373,11 @@ class DenseGemmKernel: atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A - a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A( + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, a, a_smem_layout, @@ -387,9 +390,11 @@ class DenseGemmKernel: ) # Setup TMA load for B - b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B( + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, b, b_smem_layout, @@ -413,7 +418,7 @@ class DenseGemmKernel: cute.make_identity_layout(c.shape), self.epi_tile ) epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) - tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom( + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, @@ -426,9 +431,7 @@ class DenseGemmKernel: self.buffer_align_bytes = 1024 c_smem_size = ( - cute.cosize(self.c_smem_layout_staged.outer) - if cutlass.const_expr(self.use_tma_store) - else 0 + cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 ) # Define shared storage for kernel @@ -472,7 +475,7 @@ class DenseGemmKernel: tma_atom_b, tma_tensor_b, tma_atom_c, - tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c, + tma_tensor_c if self.use_tma_store else c, self.cluster_layout_vmnk, self.a_smem_layout_staged, self.b_smem_layout_staged, @@ -556,12 +559,12 @@ class DenseGemmKernel: tmem_holding_buf = storage.tmem_holding_buf # Initialize mainloop ab_pipeline (barrier) and states - ab_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - ab_pipeline_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, num_tma_producer + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer ) - ab_pipeline = utils.PipelineTmaUmma.create( + ab_pipeline = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, @@ -569,30 +572,30 @@ class DenseGemmKernel: tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, ) - ab_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.num_ab_stage + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage ) - ab_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.num_ab_stage + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage ) # Initialize acc_pipeline (barrier) and states - acc_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) - acc_pipeline_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta ) - acc_pipeline = utils.PipelineUmmaAsync.create( + acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, producer_group=acc_pipeline_producer_group, consumer_group=acc_pipeline_consumer_group, cta_layout_vmnk=cluster_layout_vmnk, ) - acc_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.num_acc_stage + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage ) - acc_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.num_acc_stage + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage ) # Tensor memory dealloc barrier init @@ -600,7 +603,7 @@ class DenseGemmKernel: if warp_idx == 0: num_tmem_dealloc_threads = 32 with cute.arch.elect_one(): - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads ) cute.arch.mbarrier_init_fence() @@ -617,7 +620,7 @@ class DenseGemmKernel: storage.sC.get_tensor( c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner ) - if cutlass.const_expr(self.use_tma_store) + if self.use_tma_store else None ) # (MMA, MMA_M, MMA_K, STAGE) @@ -634,7 +637,7 @@ class DenseGemmKernel: # a_full_mcast_mask = None b_full_mcast_mask = None - if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs: + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): a_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 ) @@ -645,15 +648,15 @@ class DenseGemmKernel: # # Local_tile partition global tensors # - # (bM, bK, loopM, loopK, loopL) + # (bM, bK, RestM, RestK, RestL) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) - # (bN, bK, loopN, loopK, loopL) + # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) - # (bM, bN, loopM, loopN, loopL) + # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) @@ -663,11 +666,11 @@ class DenseGemmKernel: # Partition global tensor for TiledMMA_A/B/C # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) - # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) tCgC = thr_mma.partition_C(gC_mnl) # @@ -678,7 +681,7 @@ class DenseGemmKernel: cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape ) # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) + # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( tma_atom_a, block_in_cluster_coord_vmnk[2], @@ -691,7 +694,7 @@ class DenseGemmKernel: cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopN, loopK, loopL) + # ((atom_v, rest_v), RestN, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( tma_atom_b, block_in_cluster_coord_vmnk[1], @@ -771,9 +774,9 @@ class DenseGemmKernel: # # Slice to per mma tile index # - # ((atom_v, rest_v), loopK) + # ((atom_v, rest_v), RestK) tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), loopK) + # ((atom_v, rest_v), RestK) tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] if cutlass.const_expr(self.use_tma_store): # ((ATOM_V, REST_V), EPI_M, EPI_N) @@ -797,7 +800,7 @@ class DenseGemmKernel: # # Prefetch TMA load A/B # - for prefetch_idx in cutlass.range_dynamic(prefetch_k_block_cnt, unroll=1): + for prefetch_idx in cutlass.range(prefetch_k_block_cnt, unroll=1): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) @@ -833,7 +836,7 @@ class DenseGemmKernel: # # MMA mainloop # - for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1): + for k_block in range(k_block_cnt): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) @@ -860,7 +863,7 @@ class DenseGemmKernel: # tCtAcc += tCrA * tCrB num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in range(num_kphases): + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): kphase_coord = (None, None, kphase_idx, ab_consumer_state.index) cute.gemm( @@ -917,10 +920,10 @@ class DenseGemmKernel: c_pipeline = None if cutlass.const_expr(self.use_tma_store): # Initialize tma store c_pipeline - c_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta ) - c_pipeline = utils.PipelineTmaStore.create( + c_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.num_c_stage, producer_group=c_producer_group, ) @@ -929,7 +932,7 @@ class DenseGemmKernel: # Store accumulator to global memory in subtiles # subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - for subtile_idx in cutlass.range_dynamic(subtile_cnt): + for subtile_idx in range(subtile_cnt): # # Load accumulator from tensor memory buffer to register # @@ -1007,7 +1010,7 @@ class DenseGemmKernel: # if warp_idx == 0: # Reverse prefetch_k_block_cnt times to next available buffer - for i in cutlass.range_dynamic(prefetch_k_block_cnt): + for i in range(prefetch_k_block_cnt): ab_producer_state.reverse() ab_pipeline.producer_tail(ab_producer_state) return @@ -1063,11 +1066,11 @@ class DenseGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_mnl_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) tTR_rAcc = cute.make_fragment( @@ -1149,7 +1152,7 @@ class DenseGemmKernel: - tTR_gC: The partitioned global tensor C :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] """ - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) @@ -1158,7 +1161,7 @@ class DenseGemmKernel: sC_for_tma_partition = cute.group_modes(sC, 0, 2) gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) bSG_sC, bSG_gC = cpasync.tma_partition( tma_atom_c, 0, @@ -1169,7 +1172,7 @@ class DenseGemmKernel: return tma_atom_c, bSG_sC, bSG_gC else: tiled_copy_t2r = atom - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) tTR_gC = thr_copy_t2r.partition_D(gC_epi) # (T2R, T2R_M, T2R_N) @@ -1188,7 +1191,7 @@ class DenseGemmKernel: epi_tile: cute.Tile, c_dtype: Type[cutlass.Numeric], c_layout: utils.LayoutEnum, - num_smem_capacity: int, + smem_capacity: int, occupancy: int, use_tma_store: bool, ) -> Tuple[int, int, int]: @@ -1208,8 +1211,8 @@ class DenseGemmKernel: :type c_dtype: type[cutlass.Numeric] :param c_layout: Layout enum of operand C in global memory. :type c_layout: utils.LayoutEnum - :param num_smem_capacity: Total available shared memory capacity in bytes. - :type num_smem_capacity: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int :param occupancy: Target number of CTAs per SM (occupancy). :type occupancy: int :param use_tma_store: Whether TMA store is enabled. @@ -1263,7 +1266,7 @@ class DenseGemmKernel: # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B stage num_ab_stage = ( - num_smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) ) // ab_bytes_per_stage # Refine epilogue stages: @@ -1271,7 +1274,7 @@ class DenseGemmKernel: # Add remaining unused smem to epilogue if use_tma_store: num_c_stage += ( - num_smem_capacity + smem_capacity - ab_bytes_per_stage * num_ab_stage - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) ) // ((occupancy + 1) * c_bytes_per_stage) @@ -1309,36 +1312,6 @@ class DenseGemmKernel: return grid - @staticmethod - def _get_tma_atom_kind( - atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean - ) -> Union[ - cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp - ]: - """ - Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. - - :param atom_sm_cnt: The number of SMs - :type atom_sm_cnt: cutlass.Int32 - :param mcast: The multicast flag - :type mcast: cutlass.Boolean - - :return: The appropriate TMA copy atom kind - :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp - - :raise ValueError: If the atom_sm_cnt is invalid - """ - if atom_sm_cnt == 2 and mcast: - return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) - elif atom_sm_cnt == 2 and not mcast: - return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) - elif atom_sm_cnt == 1 and mcast: - return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) - elif atom_sm_cnt == 1 and not mcast: - return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) - - raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") - @staticmethod def _compute_num_tmem_alloc_cols( tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int] diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py index abc2597d..e577b7a7 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py @@ -37,6 +37,7 @@ import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.torch as cutlass_torch import cutlass.utils as utils +import cutlass.pipeline as pipeline import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cute.runtime import from_dlpack @@ -225,7 +226,7 @@ class PersistentDenseGemmKernel: self.cta_sync_bar_id = 0 self.epilog_sync_bar_id = 1 self.tmem_ptr_sync_bar_id = 2 - self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -297,7 +298,7 @@ class PersistentDenseGemmKernel: self.epi_tile, self.c_dtype, self.c_layout, - self.num_smem_capacity, + self.smem_capacity, self.occupancy, self.use_tma_store, ) @@ -389,9 +390,11 @@ class PersistentDenseGemmKernel: atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A - a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A( + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, a, a_smem_layout, @@ -404,9 +407,11 @@ class PersistentDenseGemmKernel: ) # Setup TMA load for B - b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B( + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, b, b_smem_layout, @@ -430,7 +435,7 @@ class PersistentDenseGemmKernel: cute.make_identity_layout(c.shape), self.epi_tile ) epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) - tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom( + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, @@ -571,12 +576,12 @@ class PersistentDenseGemmKernel: tmem_holding_buf = storage.tmem_holding_buf # Initialize mainloop ab_pipeline (barrier) and states - ab_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - ab_pipeline_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, num_tma_producer + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer ) - ab_pipeline = utils.PipelineTmaUmma.create( + ab_pipeline = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, @@ -586,14 +591,14 @@ class PersistentDenseGemmKernel: ) # Initialize acc_pipeline (barrier) and states - acc_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_acc_consumer_threads = len(self.epilog_warp_id) * ( 2 if use_2cta_instrs else 1 ) - acc_pipeline_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, num_acc_consumer_threads + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads ) - acc_pipeline = utils.PipelineUmmaAsync.create( + acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, producer_group=acc_pipeline_producer_group, @@ -606,7 +611,7 @@ class PersistentDenseGemmKernel: if warp_idx == self.tma_warp_id: num_tmem_dealloc_threads = 32 with cute.arch.elect_one(): - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads ) cute.arch.mbarrier_init_fence() @@ -640,7 +645,7 @@ class PersistentDenseGemmKernel: # a_full_mcast_mask = None b_full_mcast_mask = None - if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs: + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): a_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 ) @@ -651,15 +656,15 @@ class PersistentDenseGemmKernel: # # Local_tile partition global tensors # - # (bM, bK, loopM, loopK, loopL) + # (bM, bK, RestM, RestK, RestL) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) - # (bN, bK, loopN, loopK, loopL) + # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) - # (bM, bN, loopM, loopN, loopL) + # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) @@ -669,11 +674,11 @@ class PersistentDenseGemmKernel: # Partition global tensor for TiledMMA_A/B/C # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) - # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) tCgC = thr_mma.partition_C(gC_mnl) # @@ -684,7 +689,7 @@ class PersistentDenseGemmKernel: cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape ) # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) + # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( tma_atom_a, block_in_cluster_coord_vmnk[2], @@ -697,7 +702,7 @@ class PersistentDenseGemmKernel: cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) + # ((atom_v, rest_v), RestM, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( tma_atom_b, block_in_cluster_coord_vmnk[1], @@ -743,12 +748,11 @@ class PersistentDenseGemmKernel: ) work_tile = tile_sched.initial_work_tile_info() - ab_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.num_ab_stage + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage ) while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( @@ -760,11 +764,11 @@ class PersistentDenseGemmKernel: # # Slice to per mma tile index # - # ((atom_v, rest_v), loopK) + # ((atom_v, rest_v), RestK) tAgA_slice = tAgA[ (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) ] - # ((atom_v, rest_v), loopK) + # ((atom_v, rest_v), RestK) tBgB_slice = tBgB[ (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] @@ -779,7 +783,7 @@ class PersistentDenseGemmKernel: # # Tma load loop # - for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1): + for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire( ab_producer_state, peek_ab_empty_status @@ -852,15 +856,14 @@ class PersistentDenseGemmKernel: ) work_tile = tile_sched.initial_work_tile_info() - ab_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.num_ab_stage + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage ) - acc_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.num_acc_stage + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage ) while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( @@ -895,7 +898,7 @@ class PersistentDenseGemmKernel: # # Mma mainloop # - for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1): + for k_block in range(k_block_cnt): if is_leader_cta: # Conditionally wait for AB buffer full ab_pipeline.consumer_wait( @@ -904,7 +907,7 @@ class PersistentDenseGemmKernel: # tCtAcc += tCrA * tCrB num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in range(num_kphases): + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): kphase_coord = ( None, None, @@ -989,10 +992,12 @@ class PersistentDenseGemmKernel: # Partition for epilogue # epi_tidx = tidx - tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( - self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs - ) + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs ) tTR_rC = None @@ -1008,16 +1013,20 @@ class PersistentDenseGemmKernel: tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( tiled_copy_t2r, tTR_rC, epi_tidx, sC ) - tma_atom_c, bSG_sC, bSG_gC_partitioned = ( - self.epilog_gmem_copy_and_partition( - epi_tidx, tma_atom_c, tCgC, epi_tile, sC - ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC ) else: - simt_atom, tTR_rC, tTR_gC_partitioned = ( - self.epilog_gmem_copy_and_partition( - epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC - ) + ( + simt_atom, + tTR_rC, + tTR_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC ) # @@ -1028,25 +1037,24 @@ class PersistentDenseGemmKernel: ) work_tile = tile_sched.initial_work_tile_info() - acc_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.num_acc_stage + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage ) c_pipeline = None if cutlass.const_expr(self.use_tma_store): # Threads/warps participating in tma store pipeline - c_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 32 * len(self.epilog_warp_id), ) - c_pipeline = utils.PipelineTmaStore.create( + c_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.num_c_stage, producer_group=c_producer_group, ) while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( @@ -1105,7 +1113,7 @@ class PersistentDenseGemmKernel: # subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt - for subtile_idx in cutlass.range_dynamic(subtile_cnt): + for subtile_idx in cutlass.range(subtile_cnt): # # Load accumulator from tensor memory buffer to register # @@ -1259,11 +1267,11 @@ class PersistentDenseGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_mnl_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) tTR_rAcc = cute.make_fragment( @@ -1346,7 +1354,7 @@ class PersistentDenseGemmKernel: - tTR_gC: The partitioned global tensor C :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] """ - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) @@ -1355,7 +1363,7 @@ class PersistentDenseGemmKernel: sC_for_tma_partition = cute.group_modes(sC, 0, 2) gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) bSG_sC, bSG_gC = cpasync.tma_partition( tma_atom_c, 0, @@ -1366,7 +1374,7 @@ class PersistentDenseGemmKernel: return tma_atom_c, bSG_sC, bSG_gC else: tiled_copy_t2r = atom - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) tTR_gC = thr_copy_t2r.partition_D(gC_epi) # (T2R, T2R_M, T2R_N) @@ -1385,7 +1393,7 @@ class PersistentDenseGemmKernel: epi_tile: cute.Tile, c_dtype: Type[cutlass.Numeric], c_layout: utils.LayoutEnum, - num_smem_capacity: int, + smem_capacity: int, occupancy: int, use_tma_store: bool, ) -> Tuple[int, int, int]: @@ -1405,8 +1413,8 @@ class PersistentDenseGemmKernel: :type c_dtype: type[cutlass.Numeric] :param c_layout: Layout enum of operand C. :type c_layout: utils.LayoutEnum - :param num_smem_capacity: Total available shared memory capacity in bytes. - :type num_smem_capacity: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int :param occupancy: Target number of CTAs per SM (occupancy). :type occupancy: int :param use_tma_store: Whether TMA store is enabled. @@ -1461,7 +1469,7 @@ class PersistentDenseGemmKernel: # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B stage num_ab_stage = ( - num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) ) // ab_bytes_per_stage # Refine epilogue stages: @@ -1469,7 +1477,7 @@ class PersistentDenseGemmKernel: # Add remaining unused smem to epilogue if use_tma_store: num_c_stage += ( - num_smem_capacity + smem_capacity - occupancy * ab_bytes_per_stage * num_ab_stage - occupancy * (mbar_helpers_bytes + c_bytes) ) // (occupancy * c_bytes_per_stage) @@ -1512,36 +1520,6 @@ class PersistentDenseGemmKernel: return tile_sched_params, grid - @staticmethod - def _get_tma_atom_kind( - atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean - ) -> Union[ - cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp - ]: - """ - Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. - - :param atom_sm_cnt: The number of SMs - :type atom_sm_cnt: cutlass.Int32 - :param mcast: The multicast flag - :type mcast: cutlass.Boolean - - :return: The appropriate TMA copy atom kind - :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp - - :raise ValueError: If the atom_sm_cnt is invalid - """ - if atom_sm_cnt == 2 and mcast: - return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) - elif atom_sm_cnt == 2 and not mcast: - return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) - elif atom_sm_cnt == 1 and mcast: - return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) - elif atom_sm_cnt == 1 and not mcast: - return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) - - raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") - @staticmethod def _compute_num_tmem_alloc_cols( tiled_mma: cute.TiledMma, diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py new file mode 100644 index 00000000..8f6ab9fb --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py @@ -0,0 +1,1852 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL with compiler generated software pipeline. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +4. Type convert C matrix to output type. +5. Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. +6. Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/dense_gemm_software_pipeline.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Blackwell tcgen05 MMA tile shape used 2 cta with 256x128 +MMA tile and the cluster shape is (2,1). The input, mma accumulator and output +data type are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm_software_pipeline.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +Constraints: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below DenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class PipelineStateMinimal: + """ + Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. + """ + + def __init__(self, count, index, phase): + self.count = count + self.index = index + self.phase = phase + + +class DenseGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tiler (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = DenseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. + :type use_tma_store: bool + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + self.threads_per_cta = 128 + self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if self.use_tma_store + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma descriptor + # + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coords inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == 0: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if self.use_tma_store + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_block_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + + # + # Alloc tensor memory buffer + # + if warp_idx == 0: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + cute.arch.barrier() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # (MMA, MMA_M, MMA_N) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( + tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( + tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + + # /////////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # /////////////////////////////////////////////////////////////////////////////// + prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt) + if warp_idx == 0: + for k_block in cutlass.range( + k_block_cnt, + pipelining=self.num_ab_stage - 2, + ): + ab_producer_state = PipelineStateMinimal( + k_block, + k_block % self.num_ab_stage, + cutlass.Int32((k_block // self.num_ab_stage) % 2) ^ 1, + ) + + ab_consumer_state = PipelineStateMinimal( + k_block, + k_block % self.num_ab_stage, + cutlass.Int32((k_block // self.num_ab_stage) % 2), + ) + + # wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + if is_leader_cta: + # Wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state) + + # tCtAcc += tCrA * tCrB + num_kphases = cute.size(tCrA, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx, ab_consumer_state.index) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kphase_coord], + tCrB[kphase_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kphase + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + # Async arrive accumulator buffer full + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # + # Epilogue + # + + # Release tensor memory allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier() + + # + # TMA store C to global memory + # + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier() + else: + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + + # + # Dealloc the tensor memory buffer + # + cute.arch.barrier() + if warp_idx == 0: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + # + # Wait A/B buffer empty + # + if warp_idx == 0: + ab_producer_state = PipelineStateMinimal( + k_block_cnt, + k_block_cnt % self.num_ab_stage, + cutlass.Int32((k_block_cnt // self.num_ab_stage) % 2) ^ 1, + ) + ab_pipeline.producer_acquire(ab_producer_state) + return + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy( + copy_atom_r2s, + layout_tv=tiled_copy_t2r.layout_dst_tv_tiled, + tiler_mn=tiled_copy_t2r.tiler_mn, + ) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tile. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - ab_bytes_per_stage * num_ab_stage + - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ((occupancy + 1) * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + ) -> Tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + cluster_shape_mnl = (*cluster_shape_mn, 1) + + grid = cute.round_up( + ( + cute.ceil_div(c.layout.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c.layout.shape[1], cta_tile_shape_mnk[1]), + c.layout.shape[2], + ), + cluster_shape_mnl, + ) + + return grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int] + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not DenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not DenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + measure_launch_overhead=False, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + print(f"Running B100 software pipeline Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not DenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a_ref, a_tensor, a_torch = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = DenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) + + # Launch GPU kernel + # Warm up + for i in range(warmup_iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + # Execution + for i in range(iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + + # Compute reference result + if not skip_ref_check: + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # k major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + # or: return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tiler (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run_dense_gemm( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/fmha.py b/examples/python/CuTeDSL/blackwell/fmha.py index 144ba01b..ce4cafe9 100644 --- a/examples/python/CuTeDSL/blackwell/fmha.py +++ b/examples/python/CuTeDSL/blackwell/fmha.py @@ -40,9 +40,11 @@ import cutlass import cutlass.cute as cute import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils as utils +import cutlass.pipeline as pipeline import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Int32, Int64, Float32, Boolean """ A fused multi-head attention (FMHA) example for the NVIDIA Blackwell SM100 architecture using CUTE DSL @@ -130,13 +132,12 @@ def create_fmha_static_tile_scheduler_params( ) -> FmhaStaticTileSchedulerParams: return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) - class FmhaStaticTileScheduler: def __init__( self, params: FmhaStaticTileSchedulerParams, - current_work_linear_idx: cutlass.Int32, + current_work_linear_idx: Int32, blk_coord: cute.Coord, grid_shape: cute.Shape, *, @@ -178,6 +179,14 @@ class FmhaStaticTileScheduler: else: return params.problem_shape_mbh + @staticmethod + def check_valid_work_for_seqlen_q( + q_tiler: int, + current_idx: Int32, + seqlen_q: Int32, + ) -> Boolean: + return current_idx * q_tiler < seqlen_q + def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo: is_valid = ( self._current_work_linear_idx < self._num_blocks @@ -243,118 +252,6 @@ class MaskType(enum.Enum): RESIDUAL_MASK = enum.auto() CAUSAL_MASK = enum.auto() - -class FusedMask: - def __init__( - self, - mask_type: MaskType, - seq_len_k: cutlass.Int32, - *, - loc=None, - ip=None, - ): - self._mask_type = mask_type - self._seq_len_k = seq_len_k - self._loc = loc - self._ip = ip - - def get_trip_count( - self, - blk_coord: cute.Coord, - tile_shape: cute.Shape, - ) -> cutlass.Int32: - result = 0 - if ( - self._mask_type == MaskType.NO_MASK - or self._mask_type == MaskType.RESIDUAL_MASK - ): - result = cute.ceil_div(self._seq_len_k, tile_shape[1]) - elif self._mask_type == MaskType.CAUSAL_MASK: - max_blocks_k = cute.ceil_div(self._seq_len_k, tile_shape[1]) - max_blocks_q = cute.ceil_div( - (blk_coord[0] + 1) * tile_shape[0], tile_shape[1] - ) - result = cutlass.min(max_blocks_k, max_blocks_q) - return result - - @cute.jit - def get_masked_trip_count( - self, - blk_coord: cute.Coord, - tile_shape: cute.Shape, - ) -> cutlass.Int32: - result = 0 - if self._mask_type == MaskType.NO_MASK: - result = 0 - elif self._mask_type == MaskType.RESIDUAL_MASK: - if self._seq_len_k % tile_shape[1] != 0: - result = 1 - else: - result = 0 - elif self._mask_type == MaskType.CAUSAL_MASK: - result = cute.ceil_div(tile_shape[0], tile_shape[1]) - return result - - @cute.jit - def get_unmasked_trip_count( - self, - blk_coord: cute.Coord, - tile_shape: cute.Shape, - ) -> cutlass.Int32: - result = 0 - if self._mask_type == MaskType.NO_MASK: - result = self.get_trip_count(blk_coord, tile_shape) - elif self._mask_type == MaskType.RESIDUAL_MASK: - if self._seq_len_k % tile_shape[1] != 0: - result = self.get_trip_count(blk_coord, tile_shape) - 1 - else: - result = self.get_trip_count(blk_coord, tile_shape) - elif self._mask_type == MaskType.CAUSAL_MASK: - result = self.get_trip_count( - blk_coord, tile_shape - ) - self.get_masked_trip_count(blk_coord, tile_shape) - return result - - @cute.jit - def apply_mask( - self, - acc_qk: cute.Tensor, - index_qk: cute.Tensor, - ): - if self._mask_type == MaskType.RESIDUAL_MASK: - for i in range(cute.size(acc_qk)): - pos = index_qk[i] - if pos[1] >= self._seq_len_k: - acc_qk[i] = -cutlass.Float32.inf - elif self._mask_type == MaskType.CAUSAL_MASK: - for i in range(cute.size(acc_qk)): - pos = index_qk[i] - if pos[0] < pos[1] or pos[1] >= self._seq_len_k: - acc_qk[i] = -cutlass.Float32.inf - - def __extract_mlir_values__(self): - values, self._values_pos = [], [] - for obj in [self._mask_type, self._seq_len_k]: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - obj_list = [] - for obj, n_items in zip([self._mask_type, self._seq_len_k], self._values_pos): - obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return FusedMask(*(tuple(obj_list)), loc=self._loc) - - -def create_fused_mask( - mask_type: MaskType, - seq_len_k: cutlass.Int32, -) -> FusedMask: - return FusedMask(mask_type, seq_len_k) - - class BlackwellFusedMultiHeadAttentionForward: def __init__( self, @@ -409,7 +306,6 @@ class BlackwellFusedMultiHeadAttentionForward: self.cluster_shape_mn = (1, 1) self.is_persistent = is_persistent self.mask_type = mask_type - self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) @@ -481,12 +377,15 @@ class BlackwellFusedMultiHeadAttentionForward: @cute.jit def __call__( self, - q: cute.Tensor, - k: cute.Tensor, - v: cute.Tensor, - o: cute.Tensor, - scale_softmax_log2: cutlass.Float32, - scale_output: cutlass.Float32, + q_iter: cute.Pointer, + k_iter: cute.Pointer, + v_iter: cute.Pointer, + o_iter: cute.Pointer, + problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32], + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, stream: cuda.CUstream, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -502,23 +401,62 @@ class BlackwellFusedMultiHeadAttentionForward: 5. Grid and work scheduling computation 6. Kernel launch with appropriate parameters - :param q: The query tensor with shape [seq_len_q, d_head, h_q, b] - :type q: cute.Tensor - :param k: The key tensor with shape [seq_len_k, d_head, h_k, b] - :type k: cute.Tensor - :param v: The value tensor with shape [seq_len_k, d_head, h_v, b] - :type v: cute.Tensor - :param o: The output tensor with shape [seq_len_q, d_head, h_q, b] - :type o: cute.Tensor + :param q_iter: The query tensor pointer + :type q_iter: cute.Pointer + :param k_iter: The key tensor pointer + :type k_iter: cute.Pointer + :param v_iter: The value tensor pointer + :type v_iter: cute.Pointer + :param o_iter: The output tensor pointer + :type o_iter: cute.Pointer + :param problem_size: The problem size with shape [b, s_q, s_k, h_q, h_k, d]. If cum_seqlen_q or cum_seqlen_k is not None, s_q and s_k are the max of the cumulative sequence length respectively. + :type problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32] + :param cum_seqlen_q: The cumulative sequence length tensor for query + :type cum_seqlen_q: cute.Tensor | None + :param cum_seqlen_k: The cumulative sequence length tensor for key + :type cum_seqlen_k: cute.Tensor | None :param scale_softmax_log2: The log2 scale factor for softmax - :type scale_softmax_log2: cutlass.Float32 + :type scale_softmax_log2: Float32 :param scale_output: The scale factor for the output - :type scale_output: cutlass.Float32 + :type scale_output: Float32 :param stream: The CUDA stream to execute the kernel on :type stream: cuda.CUstream :raises TypeError: If tensor data types don't match or aren't supported :raises RuntimeError: If tensor layouts aren't in supported formats """ + b, s_q, s_k, h_q, h_k, d = problem_size + h_r = h_q // h_k + qo_offset = 0 if cum_seqlen_q is None else -s_q * d * h_r * h_k + kv_offset = 0 if cum_seqlen_k is None else -s_k * d * h_k + b_qo = b if cum_seqlen_q is None else s_q * (1 + b) + b_kv = b if cum_seqlen_k is None else s_k * (1 + b) + stride_b_qo = h_r * h_k * s_q * d if cum_seqlen_q is None else d * h_r * h_k + stride_b_kv = h_k * s_k * d if cum_seqlen_k is None else d * h_k + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q, d, ((h_r, h_k), b_qo)), + stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)), + ) + q = cute.make_tensor(q_iter + qo_offset, q_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k, d, ((h_r, h_k), b_kv)), + stride=(d * h_k, 1, ((0, d), stride_b_kv)), + ) + k = cute.make_tensor(k_iter + kv_offset, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (d, s_k, ((h_r, h_k), b_kv)), + stride=(1, d * h_k, ((0, d), stride_b_kv)), + ) + v = cute.make_tensor(v_iter + kv_offset, v_layout) + # (s, d, ((h_r, h_k), b)) + o_layout = cute.make_layout( + (s_q, d, ((h_r, h_k), b_qo)), + stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)), + ) + o = cute.make_tensor(o_iter + qo_offset, o_layout) # setup static attributes before smem/grid/tma computation self.q_dtype = q.element_type @@ -526,34 +464,11 @@ class BlackwellFusedMultiHeadAttentionForward: self.v_dtype = v.element_type self.o_dtype = o.element_type - # (s, d, 1, h_k, b) -> (s, d, ((h_r, h_k), b)) - k = cute.make_tensor( - k.iterator, - cute.make_layout( - (k.shape[0], k.shape[1], ((q.shape[2], k.shape[3]), k.shape[4])), - stride=( - k.layout.stride[0], - k.layout.stride[1], - ((0, k.layout.stride[3]), k.layout.stride[4]), - ), - ), + self.tile_sched_params, grid = self._compute_grid( + cute.shape((s_q, d, ((h_r, h_k), b))), + self.cta_tiler, + self.is_persistent, ) - # (s, d, 1, h_k, b) -> (d, s, ((h_r, h_k), b)) - v = cute.make_tensor( - v.iterator, - cute.make_layout( - (v.shape[1], v.shape[0], ((q.shape[2], v.shape[3]), v.shape[4])), - stride=( - v.layout.stride[1], - v.layout.stride[0], - ((0, v.layout.stride[3]), v.layout.stride[4]), - ), - ), - ) - - # (s, d, h_r, h_k, b) -> (s, d, ((h_r, h_k), b)) - q = cute.group_modes(cute.group_modes(q, begin=2, end=4), begin=2, end=4) - o = cute.group_modes(cute.group_modes(o, begin=2, end=4), begin=2, end=4) self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() @@ -604,7 +519,6 @@ class BlackwellFusedMultiHeadAttentionForward: self.epi_tile = self.pv_mma_tiler[:2] - q_smem_layout_staged = sm100_utils.make_smem_layout_a( qk_tiled_mma, self.qk_mma_tiler, @@ -641,7 +555,7 @@ class BlackwellFusedMultiHeadAttentionForward: tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_q, tma_tensor_q = cute.nvgpu.make_tma_tile_atom_A( + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, q, q_smem_layout, @@ -652,7 +566,7 @@ class BlackwellFusedMultiHeadAttentionForward: # TMA load for K k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_k, tma_tensor_k = cute.nvgpu.make_tma_tile_atom_B( + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, k, k_smem_layout, @@ -662,7 +576,7 @@ class BlackwellFusedMultiHeadAttentionForward: ) # TMA load for V v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_v, tma_tensor_v = cute.nvgpu.make_tma_tile_atom_B( + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, v, v_smem_layout, @@ -676,7 +590,7 @@ class BlackwellFusedMultiHeadAttentionForward: ) o_smem_layout = cute.select(o_smem_layout_staged, mode=[0, 1]) - tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tma_tile_atom( + tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tiled_tma_atom( tma_store_op, o, o_smem_layout, @@ -688,40 +602,23 @@ class BlackwellFusedMultiHeadAttentionForward: self.tma_copy_q_bytes = q_copy_size self.tma_copy_kv_bytes = k_copy_size - self.tile_sched_params, grid = self._compute_grid( - o, - self.cta_tiler, - self.is_persistent, - ) - @cute.struct class SharedStorage: # Pipeline barriers - load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.q_stage * 2] - load_kv_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2] - mma_s0_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.mma_softmax_stage * 2 - ] - mma_s1_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.mma_softmax_stage * 2 - ] - s0_corr_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.softmax_corr_stage * 2 - ] - s1_corr_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.softmax_corr_stage * 2 - ] + load_q_mbar_ptr: cute.struct.MemRange[Int64, self.q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[Int64, self.kv_stage * 2] + mma_s0_mbar_ptr: cute.struct.MemRange[Int64, self.mma_softmax_stage * 2] + mma_s1_mbar_ptr: cute.struct.MemRange[Int64, self.mma_softmax_stage * 2] + s0_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] + s1_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] s0_s1_sequence_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.softmax_warpgroup_count + Int64, self.softmax_warpgroup_count ] - corr_epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_stage * 2] - mma_corr_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.mma_corr_stage * 2 - ] - max_reg_setting_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] - tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + corr_epi_mbar_ptr: cute.struct.MemRange[Int64, self.epi_stage * 2] + mma_corr_mbar_ptr: cute.struct.MemRange[Int64, self.mma_corr_stage * 2] + tmem_dealloc_mbar_ptr: cute.struct.MemRange[Int64, 1] # Tmem holding buffer - tmem_holding_buf: cutlass.Int32 + tmem_holding_buf: Int32 # Smem tensors sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, cute.cosize(o_smem_layout_staged)], @@ -737,7 +634,6 @@ class BlackwellFusedMultiHeadAttentionForward: ] self.shared_storage = SharedStorage - fused_mask = create_fused_mask(self.mask_type, k.shape[0]) # Launch the kernel synchronously self.kernel( @@ -751,6 +647,8 @@ class BlackwellFusedMultiHeadAttentionForward: tma_tensor_v, tma_atom_o, tma_tensor_o, + cum_seqlen_q, + cum_seqlen_k, scale_softmax_log2, scale_output, q_smem_layout_staged, @@ -759,7 +657,6 @@ class BlackwellFusedMultiHeadAttentionForward: v_smem_layout_staged, o_smem_layout_staged, self.tile_sched_params, - fused_mask, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], @@ -783,15 +680,16 @@ class BlackwellFusedMultiHeadAttentionForward: mV_dkl: cute.Tensor, tma_atom_o: cute.CopyAtom, mO_qdl: cute.Tensor, - scale_softmax_log2: cutlass.Float32, - scale_output: cutlass.Float32, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, q_smem_layout_staged: cute.ComposedLayout, k_smem_layout_staged: cute.ComposedLayout, p_tmem_layout_staged: cute.ComposedLayout, v_smem_layout_staged: cute.ComposedLayout, o_smem_layout_staged: cute.ComposedLayout, tile_sched_params: FmhaStaticTileSchedulerParams, - fused_mask: FusedMask, ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -827,9 +725,9 @@ class BlackwellFusedMultiHeadAttentionForward: :param mO_qdl: Partitioned output tensor :type mO_qdl: cute.Tensor :param scale_softmax_log2: The log2 scale factor for softmax - :type scale_softmax_log2: cutlass.Float32 + :type scale_softmax_log2: Float32 :param scale_output: The scale factor for the output - :type scale_output: cutlass.Float32 + :type scale_output: Float32 :param q_smem_layout_staged: Shared memory layout for query tensor :type q_smem_layout_staged: cute.ComposedLayout :param k_smem_layout_staged: Shared memory layout for key tensor @@ -842,14 +740,21 @@ class BlackwellFusedMultiHeadAttentionForward: :type o_smem_layout_staged: cute.ComposedLayout :param tile_sched_params: Scheduling parameters for work distribution :type tile_sched_params: FmhaStaticTileSchedulerParams - :param fused_mask: Masking configuration (causal/residual/none) - :type fused_mask: FusedMask """ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # coord inside cta tidx, _, _ = cute.arch.thread_idx() + # + # Prefetch tma desc + # + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_o) + # Alloc smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) @@ -881,25 +786,11 @@ class BlackwellFusedMultiHeadAttentionForward: s0_s1_sequence_pipeline = self.make_and_init_si_sequence_pipeline( storage.s0_s1_sequence_mbar_ptr.data_ptr() ) - max_reg_setting_mbar_ptr = storage.max_reg_setting_mbar_ptr.data_ptr() tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() # Correction & Epilogue & tmem barrier init if warp_idx == self.empty_warp_id: - cute.arch.mbarrier_init_arrive_cnt( - max_reg_setting_mbar_ptr, - self.threads_per_warp - * len( - ( - self.empty_warp_id, - self.load_warp_id, - self.mma_warp_id, - self.epilogue_warp_id, - *self.correction_warp_ids, - ) - ), - ) - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, self.threads_per_warp * len( @@ -930,50 +821,13 @@ class BlackwellFusedMultiHeadAttentionForward: sO = storage.sO.get_tensor( o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner ) - - # Local tile partition global tensors - # (bM, bK, loopM, loopK, loopL) need to check - gQ_qdl = cute.flat_divide(mQ_qdl, cute.select(self.qk_mma_tiler, mode=[0, 2])) qk_thr_mma = qk_tiled_mma.get_slice(0) # default 1sm - tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) - - tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( - tma_atom_q, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sQ, 0, 3), - cute.group_modes(tSgQ_qdl, 0, 3), - ) - - gK_kdl = cute.flat_divide(mK_kdl, cute.select(self.qk_mma_tiler, mode=[1, 2])) - tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) - tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( - tma_atom_k, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sK, 0, 3), - cute.group_modes(tSgK_kdl, 0, 3), - ) - - # (bM, bN, loopM, loopN, loopL) - gV_dkl = cute.flat_divide(mV_dkl, cute.select(self.pv_mma_tiler, mode=[1, 2])) - pv_thr_mma = pv_tiled_mma.get_slice(0) # default 1sm - tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) - tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( - tma_atom_v, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sV, 0, 3), - cute.group_modes(tSgV_dkl, 0, 3), - ) tSrQ = qk_thr_mma.make_fragment_A(sQ) tSrK = qk_thr_mma.make_fragment_B(sK) tOrV = pv_thr_mma.make_fragment_B(sV) - gO_qdl = cute.flat_divide(mO_qdl, cute.select(self.pv_mma_tiler, mode=[0, 1])) - qk_acc_shape = qk_thr_mma.partition_shape_C( (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) ) @@ -1014,20 +868,18 @@ class BlackwellFusedMultiHeadAttentionForward: # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.empty_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) - cute.arch.mbarrier_arrive(max_reg_setting_mbar_ptr) # /////////////////////////////////////////////////////////////////////////////// # LOAD # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.load_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - cute.arch.mbarrier_arrive(max_reg_setting_mbar_ptr) - q_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.q_stage + q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.q_stage ) - kv_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.kv_stage + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage ) tile_sched = create_fmha_static_tile_scheduler( @@ -1037,63 +889,120 @@ class BlackwellFusedMultiHeadAttentionForward: while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx - tQgQ = tQgQ_qdl[None, None, 0, curr_block_coord[2]] - tKgK = tKgK_kdl[None, None, 0, curr_block_coord[2]] - tVgV = tVgV_dkl[None, 0, None, curr_block_coord[2]] + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] - # Q0 - q0_coord = 2 * curr_block_coord[0] - load_q_pipeline.producer_acquire(q_producer_state) - cute.copy( - tma_atom_q, - tQgQ[None, q0_coord], - tQsQ[None, q_producer_state.index], - tma_bar_ptr=load_q_pipeline.producer_get_barrier(q_producer_state), - ) - q_producer_state.advance() + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) - # K0 - kv_coord = 0 # seqlen_kv_loop - load_kv_pipeline.producer_acquire(kv_producer_state) - cute.copy( - tma_atom_k, - tKgK[None, kv_coord], - tKsK[None, kv_producer_state.index], - tma_bar_ptr=load_kv_pipeline.producer_get_barrier( - kv_producer_state - ), - ) - kv_producer_state.advance() + if not continue_cond: + mQ_qdl_ = mQ_qdl + mK_kdl_ = mK_kdl + mV_dkl_ = mV_dkl + seqlen_k = mK_kdl.shape[0] + curr_block_coord_q = curr_block_coord + curr_block_coord_kv = curr_block_coord - # Q1 - q1_coord = q0_coord + 1 - load_q_pipeline.producer_acquire(q_producer_state) - cute.copy( - tma_atom_q, - tQgQ[None, q1_coord], - tQsQ[None, q_producer_state.index], - tma_bar_ptr=load_q_pipeline.producer_get_barrier(q_producer_state), - ) - q_producer_state.advance() + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mQ = ( + mQ_qdl.shape[0] - seqlen_q, + 0, + (0, cuseqlen_q + seqlen_q), + ) + mQ_qdl_ = cute.domain_offset(logical_offset_mQ, mQ_qdl) + curr_block_coord_q = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) - # V0 - load_kv_pipeline.producer_acquire(kv_producer_state) - cute.copy( - tma_atom_v, - tVgV[None, kv_coord], - tVsV[None, kv_producer_state.index], - tma_bar_ptr=load_kv_pipeline.producer_get_barrier( - kv_producer_state - ), - ) - kv_producer_state.advance() - kv_coord += 1 + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + logical_offset_mK = ( + mK_kdl.shape[0] - seqlen_k, + 0, + (0, cuseqlen_k + seqlen_k), + ) + logical_offset_mV = ( + 0, + mK_kdl.shape[0] - seqlen_k, + (0, cuseqlen_k + seqlen_k), + ) + mK_kdl_ = cute.domain_offset(logical_offset_mK, mK_kdl) + mV_dkl_ = cute.domain_offset(logical_offset_mV, mV_dkl) + curr_block_coord_kv = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) - seqlen_kv_loop_steps = ( - fused_mask.get_trip_count(curr_block_coord, self.cta_tiler) - 1 - ) - for i in cutlass.range_dynamic(0, seqlen_kv_loop_steps, 1, unroll=1): - # Ki + # Local tile partition global tensors + # (bM, bK, loopM, loopK, loopL) + gQ_qdl = cute.flat_divide( + mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2]) + ) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + tQgQ = tQgQ_qdl[None, None, 0, curr_block_coord_q[2]] + + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + tKgK = tKgK_kdl[None, None, 0, curr_block_coord_kv[2]] + + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + tVgV = tVgV_dkl[None, 0, None, curr_block_coord_kv[2]] + + # Q0 + q0_coord = 2 * curr_block_coord_q[0] + load_q_pipeline.producer_acquire(q_producer_state) + cute.copy( + tma_atom_q, + tQgQ[None, q0_coord], + tQsQ[None, q_producer_state.index], + tma_bar_ptr=load_q_pipeline.producer_get_barrier( + q_producer_state + ), + ) + q_producer_state.advance() + + # K0 + kv_coord = 0 # seqlen_kv_loop load_kv_pipeline.producer_acquire(kv_producer_state) cute.copy( tma_atom_k, @@ -1104,7 +1013,21 @@ class BlackwellFusedMultiHeadAttentionForward: ), ) kv_producer_state.advance() - # Vi + + # Q1 + q1_coord = q0_coord + 1 + load_q_pipeline.producer_acquire(q_producer_state) + cute.copy( + tma_atom_q, + tQgQ[None, q1_coord], + tQsQ[None, q_producer_state.index], + tma_bar_ptr=load_q_pipeline.producer_get_barrier( + q_producer_state + ), + ) + q_producer_state.advance() + + # V0 load_kv_pipeline.producer_acquire(kv_producer_state) cute.copy( tma_atom_v, @@ -1116,7 +1039,38 @@ class BlackwellFusedMultiHeadAttentionForward: ) kv_producer_state.advance() kv_coord += 1 - # End of seqlen_kv loop + + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # Ki + load_kv_pipeline.producer_acquire(kv_producer_state) + cute.copy( + tma_atom_k, + tKgK[None, kv_coord], + tKsK[None, kv_producer_state.index], + tma_bar_ptr=load_kv_pipeline.producer_get_barrier( + kv_producer_state + ), + ) + kv_producer_state.advance() + # Vi + load_kv_pipeline.producer_acquire(kv_producer_state) + + cute.copy( + tma_atom_v, + tVgV[None, kv_coord], + tVsV[None, kv_producer_state.index], + tma_bar_ptr=load_kv_pipeline.producer_get_barrier( + kv_producer_state + ), + ) + kv_producer_state.advance() + kv_coord += 1 + # End of seqlen_kv loop tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() @@ -1127,31 +1081,30 @@ class BlackwellFusedMultiHeadAttentionForward: # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - cute.arch.mbarrier_arrive(max_reg_setting_mbar_ptr) # Alloc tmem buffer - tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) cute.arch.barrier( barrier_id=self.tmem_alloc_sync_bar_id, number_of_threads=self.threads_per_warp, ) - mma_q_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.q_stage + mma_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.q_stage ) - mma_kv_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.kv_stage + mma_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage ) mma_q_release_state = mma_q_consumer_state.clone() mma_kv_release_state = mma_kv_consumer_state.clone() - mma_s0_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.mma_softmax_stage + mma_s0_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_softmax_stage ) - mma_s1_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.mma_softmax_stage + mma_s1_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_softmax_stage ) - mma_corr_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.mma_corr_stage + mma_corr_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_corr_stage ) tile_sched = create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() @@ -1160,230 +1113,258 @@ class BlackwellFusedMultiHeadAttentionForward: while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx - # GEMM_QK00 (Q0 * K0 -> S0) - # 1. wait for Q0 - load_q_pipeline.consumer_wait(mma_q_consumer_state) - tSrQ0 = tSrQ[None, None, None, mma_q_consumer_state.index] - mma_q_consumer_state.advance() - # 2. wait for K0 - load_kv_pipeline.consumer_wait(mma_kv_consumer_state) - tSrK0 = tSrK[None, None, None, mma_kv_consumer_state.index] - mma_kv_consumer_state.advance() - # 3. acquire empty S0 buffer - mma_s0_pipeline.producer_acquire(mma_s0_producer_state) - # 4. gemm - num_kphases = cute.size(tSrQ0, mode=[2]) - for kphase_idx in range(num_kphases): - kphase_coord = (None, None, kphase_idx) - qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - qk_tiled_mma, - tStS0, - tSrQ0[kphase_coord], - tSrK0[kphase_coord], - tStS0, + batch_coord = curr_block_coord[2][1] + continue_cond = False + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) ) - # 5. release S0 - mma_s0_pipeline.producer_commit(mma_s0_producer_state) - mma_s0_producer_state.advance() - # End of GEMM (Q0 * K0 -> S0) - # GEMM_QK10 (Q1 * K0 -> S1), K0 is ready in GEMM_QK00 - # 1. wait for Q1 - load_q_pipeline.consumer_wait(mma_q_consumer_state) - tSrQ1 = tSrQ[None, None, None, mma_q_consumer_state.index] - mma_q_consumer_state.advance() - # 2. acquire empty S1 - mma_s1_pipeline.producer_acquire(mma_s1_producer_state) - # 3. gemm - num_kphases = cute.size(tSrQ1, mode=[2]) - for kphase_idx in range(num_kphases): - kphase_coord = (None, None, kphase_idx) - qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - qk_tiled_mma, - tStS1, - tSrQ1[kphase_coord], - tSrK0[kphase_coord], - tStS1, - ) - # 4. release S1 - mma_s1_pipeline.producer_commit(mma_s1_producer_state) - mma_s1_producer_state.advance() - # 5. release K0 - load_kv_pipeline.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() - # End of GEMM (Q1 * K0 -> S1) - # Note: Q0 & Q1 are still needed in the seqlen_kv loop - # so we need to release them after the seqlen_kv loop + if not continue_cond: + seqlen_k = mK_kdl.shape[0] + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop - # 1. wait for V0 - load_kv_pipeline.consumer_wait(mma_kv_consumer_state) - tOrVi = tOrV[None, None, None, mma_kv_consumer_state.index] - mma_kv_consumer_state.advance() - # 2. acquire corrected O0_partial - # Note: acquire corr first to take it out of the critical - # path since softmax takes longer - mma_corr_pipeline.producer_acquire(mma_corr_producer_state) - # 3. acquire P0 - # this acquire returns the ownership of all of S0 to the mma warp - # including the P0 part (inplaced in S0) - mma_s0_pipeline.producer_acquire(mma_s0_producer_state) - # 4. gemm - num_kphases = cute.size(tOrP0, mode=[2]) - for kphase_idx in range(num_kphases): - kphase_coord = (None, None, kphase_idx) - pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - pv_tiled_mma, - tOtO0, - tOrP0[kphase_coord], - tOrVi[kphase_coord], - tOtO0, - ) - # 5. release accumulated O0_partial - mma_corr_pipeline.producer_commit(mma_corr_producer_state) - mma_corr_producer_state.advance() - # End of GEMM_PV00 (P0 * V0 -> O0_partial) - - seqlen_kv_loop_steps = ( - fused_mask.get_trip_count(curr_block_coord, self.cta_tiler) - 1 - ) - # O1 hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate - pv_whether_acc = False - for i in cutlass.range_dynamic(0, seqlen_kv_loop_steps, 1, unroll=1): - # GEMM_QK0i (Q0 * Ki -> S0) - # 1. wait for Ki + # GEMM_QK00 (Q0 * K0 -> S0) + # 1. wait for Q0 + load_q_pipeline.consumer_wait(mma_q_consumer_state) + tSrQ0 = tSrQ[None, None, None, mma_q_consumer_state.index] + mma_q_consumer_state.advance() + # 2. wait for K0 load_kv_pipeline.consumer_wait(mma_kv_consumer_state) - tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + tSrK0 = tSrK[None, None, None, mma_kv_consumer_state.index] mma_kv_consumer_state.advance() - # 2. gemm - inner_num_kphases = cute.size(tSrQ0, mode=[2]) - for kphase_idx in range(inner_num_kphases): - kphase_coord = (None, None, kphase_idx) + # 3. acquire empty S0 buffer + mma_s0_pipeline.producer_acquire(mma_s0_producer_state) + # 4. gemm + num_kphases = cute.size(tSrQ0, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_0 = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( qk_tiled_mma, tStS0, - tSrQ0[kphase_coord], - tSrKi[kphase_coord], + tSrQ0[kphase_coord_0], + tSrK0[kphase_coord_0], tStS0, ) - # 3. release S0 + # 5. release S0 mma_s0_pipeline.producer_commit(mma_s0_producer_state) mma_s0_producer_state.advance() - # End of GEMM_QK0i (Q0 * Ki -> S0) + # End of GEMM (Q0 * K0 -> S0) - # GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial), V(i-1) is ready in GEMM_PV0(i-1) - # 1. acquire corrected O1_partial - mma_corr_pipeline.producer_acquire(mma_corr_producer_state) - # 2. acquire P1 + # GEMM_QK10 (Q1 * K0 -> S1), K0 is ready in GEMM_QK00 + # 1. wait for Q1 + load_q_pipeline.consumer_wait(mma_q_consumer_state) + tSrQ1 = tSrQ[None, None, None, mma_q_consumer_state.index] + mma_q_consumer_state.advance() + # 2. acquire empty S1 mma_s1_pipeline.producer_acquire(mma_s1_producer_state) # 3. gemm - inner_num_kphases = cute.size(tOrP0, mode=[2]) - for kphase_idx in range(inner_num_kphases): - kphase_coord = (None, None, kphase_idx) - pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) - cute.gemm( - pv_tiled_mma, - tOtO1, - tOrP1[kphase_coord], - tOrVi[kphase_coord], - tOtO1, - ) - pv_whether_acc = True - # 4. release accumulated O1_partial - mma_corr_pipeline.producer_commit(mma_corr_producer_state) - mma_corr_producer_state.advance() - # 5. release V(i-1) - load_kv_pipeline.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() - # End of GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial) - - # GEMM_QK1i (Q1 * Ki -> S1), Q1 is ready in GEMM_QK10; Ki is ready in GEMM_QK0i - # 1. gemm - inner_num_kphases = cute.size(tSrQ1, mode=[2]) - for kphase_idx in range(inner_num_kphases): - kphase_coord = (None, None, kphase_idx) + num_kphases = cute.size(tSrQ1, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_1 = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( qk_tiled_mma, tStS1, - tSrQ1[kphase_coord], - tSrKi[kphase_coord], + tSrQ1[kphase_coord_1], + tSrK0[kphase_coord_1], tStS1, ) + # 4. release S1 mma_s1_pipeline.producer_commit(mma_s1_producer_state) mma_s1_producer_state.advance() - # 2. release Ki + # 5. release K0 load_kv_pipeline.consumer_release(mma_kv_release_state) mma_kv_release_state.advance() - # End of GEMM_QK1i (Q1 * Ki -> S1) + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop - # GEMM_PV0i (P0 * Vi -> O0_partial) - # 1. wait for Vi + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 load_kv_pipeline.consumer_wait(mma_kv_consumer_state) tOrVi = tOrV[None, None, None, mma_kv_consumer_state.index] mma_kv_consumer_state.advance() # 2. acquire corrected O0_partial + # Note: acquire corr first to take it out of the critical + # path since softmax takes longer mma_corr_pipeline.producer_acquire(mma_corr_producer_state) # 3. acquire P0 + # this acquire returns the ownership of all of S0 to the mma warp + # including the P0 part (inplaced in S0) mma_s0_pipeline.producer_acquire(mma_s0_producer_state) # 4. gemm - inner_num_kphases = cute.size(tOrP0, mode=[2]) - for kphase_idx in range(inner_num_kphases): - kphase_coord = (None, None, kphase_idx) - pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_2 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( pv_tiled_mma, tOtO0, - tOrP0[kphase_coord], - tOrVi[kphase_coord], + tOrP0[kphase_coord_2], + tOrVi[kphase_coord_2], tOtO0, ) # 5. release accumulated O0_partial mma_corr_pipeline.producer_commit(mma_corr_producer_state) mma_corr_producer_state.advance() - # End of GEMM_PV0i (P0 * Vi -> O0_partial) - # End of seqlen_kv loop + # End of GEMM_PV00 (P0 * V0 -> O0_partial) - # release Q0 & Q1 - load_q_pipeline.consumer_release(mma_q_release_state) - mma_q_release_state.advance() - load_q_pipeline.consumer_release(mma_q_release_state) - mma_q_release_state.advance() - - # GEMM_PV1(i_end) (P1 * Vi_end -> O1) - # 1. acquire corrected O1_partial - mma_corr_pipeline.producer_acquire(mma_corr_producer_state) - # 2. acquire P1 - mma_s1_pipeline.producer_acquire(mma_s1_producer_state) - # 3. gemm - num_kphases = cute.size(tOrP1, mode=[2]) - for kphase_idx in range(num_kphases): - kphase_coord = (None, None, kphase_idx) - pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.gemm( - pv_tiled_mma, - tOtO1, - tOrP1[kphase_coord], - tOrVi[kphase_coord], - tOtO1, + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 ) - # 4. commit accumulated O1 - mma_corr_pipeline.producer_commit(mma_corr_producer_state) - mma_corr_producer_state.advance() - # 5. release Vi_end - load_kv_pipeline.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() - # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + # O1 hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + pv_whether_acc = False + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + load_kv_pipeline.consumer_wait(mma_kv_consumer_state) + tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + mma_kv_consumer_state.advance() + # 2. gemm + inner_num_kphases = cute.size(tSrQ0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_3 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS0, + tSrQ0[kphase_coord_3], + tSrKi[kphase_coord_3], + tStS0, + ) + # 3. release S0 + mma_s0_pipeline.producer_commit(mma_s0_producer_state) + mma_s0_producer_state.advance() + # End of GEMM_QK0i (Q0 * Ki -> S0) - # Commit S0 and S1 - mma_s0_pipeline.producer_commit(mma_s0_producer_state) - mma_s0_producer_state.advance() - mma_s1_pipeline.producer_commit(mma_s1_producer_state) - mma_s1_producer_state.advance() + # GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial), V(i-1) is ready in GEMM_PV0(i-1) + # 1. acquire corrected O1_partial + mma_corr_pipeline.producer_acquire(mma_corr_producer_state) + # 2. acquire P1 + mma_s1_pipeline.producer_acquire(mma_s1_producer_state) + # 3. gemm + inner_num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_4 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + cute.gemm( + pv_tiled_mma, + tOtO1, + tOrP1[kphase_coord_4], + tOrVi[kphase_coord_4], + tOtO1, + ) + pv_whether_acc = True + # 4. release accumulated O1_partial + mma_corr_pipeline.producer_commit(mma_corr_producer_state) + mma_corr_producer_state.advance() + # 5. release V(i-1) + load_kv_pipeline.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial) + + # GEMM_QK1i (Q1 * Ki -> S1), Q1 is ready in GEMM_QK10; Ki is ready in GEMM_QK0i + # 1. gemm + inner_num_kphases = cute.size(tSrQ1, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_5 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS1, + tSrQ1[kphase_coord_5], + tSrKi[kphase_coord_5], + tStS1, + ) + mma_s1_pipeline.producer_commit(mma_s1_producer_state) + mma_s1_producer_state.advance() + # 2. release Ki + load_kv_pipeline.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_QK1i (Q1 * Ki -> S1) + + # GEMM_PV0i (P0 * Vi -> O0_partial) + # 1. wait for Vi + load_kv_pipeline.consumer_wait(mma_kv_consumer_state) + tOrVi = tOrV[None, None, None, mma_kv_consumer_state.index] + mma_kv_consumer_state.advance() + # 2. acquire corrected O0_partial + mma_corr_pipeline.producer_acquire(mma_corr_producer_state) + # 3. acquire P0 + mma_s0_pipeline.producer_acquire(mma_s0_producer_state) + # 4. gemm + inner_num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_6 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.gemm( + pv_tiled_mma, + tOtO0, + tOrP0[kphase_coord_6], + tOrVi[kphase_coord_6], + tOtO0, + ) + # 5. release accumulated O0_partial + mma_corr_pipeline.producer_commit(mma_corr_producer_state) + mma_corr_producer_state.advance() + # End of GEMM_PV0i (P0 * Vi -> O0_partial) + # End of seqlen_kv loop + + # release Q0 & Q1 + load_q_pipeline.consumer_release(mma_q_release_state) + mma_q_release_state.advance() + load_q_pipeline.consumer_release(mma_q_release_state) + mma_q_release_state.advance() + + # GEMM_PV1(i_end) (P1 * Vi_end -> O1) + # 1. acquire corrected O1_partial + mma_corr_pipeline.producer_acquire(mma_corr_producer_state) + # 2. acquire P1 + mma_s1_pipeline.producer_acquire(mma_s1_producer_state) + # 3. gemm + num_kphases = cute.size(tOrP1, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_7 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.gemm( + pv_tiled_mma, + tOtO1, + tOrP1[kphase_coord_7], + tOrVi[kphase_coord_7], + tOtO1, + ) + # 4. commit accumulated O1 + mma_corr_pipeline.producer_commit(mma_corr_producer_state) + mma_corr_producer_state.advance() + # 5. release Vi_end + load_kv_pipeline.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Commit S0 and S1 + mma_s0_pipeline.producer_commit(mma_s0_producer_state) + mma_s0_producer_state.advance() + mma_s1_pipeline.producer_commit(mma_s1_producer_state) + mma_s1_producer_state.advance() # Advance to next tile tile_sched.advance_to_next_work() @@ -1391,11 +1372,12 @@ class BlackwellFusedMultiHeadAttentionForward: # End of persistent scheduler loop # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) # Retrieving tmem ptr and make acc tmem_ptr = cute.arch.retrieve_tmem_ptr( - cutlass.Float32, + Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf, ) @@ -1407,10 +1389,9 @@ class BlackwellFusedMultiHeadAttentionForward: # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.epilogue_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - cute.arch.mbarrier_arrive(max_reg_setting_mbar_ptr) - corr_epi_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.epi_stage + corr_epi_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.epi_stage ) corr_epi_release_state = corr_epi_consumer_state.clone() @@ -1421,43 +1402,77 @@ class BlackwellFusedMultiHeadAttentionForward: while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] - o0_coord = 2 * curr_block_coord[0] - o1_coord = o0_coord + 1 - gO = gO_qdl[None, None, None, 0, curr_block_coord[2]] - tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( - tma_atom_o, - 0, - cute.make_layout(1), - cute.group_modes(sO, 0, 2), - cute.group_modes(gO, 0, 2), - ) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) - # O0 O1 using the same pipeline - # wait from corr, issue tma store on smem - # O0 - # 1. wait for O0 final - corr_epi_pipeline.consumer_wait(corr_epi_consumer_state) - corr_epi_consumer_state.advance() - # 2. copy O0 to gmem - cute.copy(tma_atom_o, tOsO[None, 0], tOgO[None, o0_coord]) - cute.arch.cp_async_bulk_commit_group() - # O1 - # 1. wait for O1 final - corr_epi_pipeline.consumer_wait(corr_epi_consumer_state) - corr_epi_consumer_state.advance() - # 2. copy O1 to gmem - cute.copy(tma_atom_o, tOsO[None, 1], tOgO[None, o1_coord]) - cute.arch.cp_async_bulk_commit_group() + if not continue_cond: + curr_block_coord_o = curr_block_coord + mO_qdl_ = mO_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mO = ( + mO_qdl_.shape[0] - seqlen_q, + 0, + (0, cuseqlen_q + seqlen_q), + ) + mO_qdl_ = cute.domain_offset(logical_offset_mO, mO_qdl_) + curr_block_coord_o = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], 0), + ) - # Ensure O0 buffer is ready to be released - cute.arch.cp_async_bulk_wait_group(1, read=True) - corr_epi_pipeline.consumer_release(corr_epi_release_state) - corr_epi_release_state.advance() - # Ensure O1 buffer is ready to be released - cute.arch.cp_async_bulk_wait_group(0, read=True) - corr_epi_pipeline.consumer_release(corr_epi_release_state) - corr_epi_release_state.advance() + o0_coord = 2 * curr_block_coord_o[0] + o1_coord = o0_coord + 1 + gO_qdl = cute.flat_divide( + mO_qdl_, cute.select(self.pv_mma_tiler, mode=[0, 1]) + ) + gO = gO_qdl[None, None, None, 0, curr_block_coord_o[2]] + tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_o, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + + # O0 O1 using the same pipeline + # wait from corr, issue tma store on smem + # O0 + # 1. wait for O0 final + corr_epi_pipeline.consumer_wait(corr_epi_consumer_state) + corr_epi_consumer_state.advance() + # 2. copy O0 to gmem + cute.copy(tma_atom_o, tOsO[None, 0], tOgO[None, o0_coord]) + cute.arch.cp_async_bulk_commit_group() + # O1 + # 1. wait for O1 final + corr_epi_pipeline.consumer_wait(corr_epi_consumer_state) + corr_epi_consumer_state.advance() + # 2. copy O1 to gmem + cute.copy(tma_atom_o, tOsO[None, 1], tOgO[None, o1_coord]) + cute.arch.cp_async_bulk_commit_group() + + # Ensure O0 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1, read=True) + corr_epi_pipeline.consumer_release(corr_epi_release_state) + corr_epi_release_state.advance() + # Ensure O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(0, read=True) + corr_epi_pipeline.consumer_release(corr_epi_release_state) + corr_epi_release_state.advance() # Advance to next tile tile_sched.advance_to_next_work() @@ -1469,11 +1484,13 @@ class BlackwellFusedMultiHeadAttentionForward: # /////////////////////////////////////////////////////////////////////////////// if warp_idx < self.softmax1_warp_ids[0]: # increase register after decreasing - cute.arch.mbarrier_wait(max_reg_setting_mbar_ptr, 0) cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) self.softmax( stage=0, + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, scale_softmax_log2=scale_softmax_log2, qk_thr_mma=qk_thr_mma, tStS=tStS, @@ -1482,7 +1499,6 @@ class BlackwellFusedMultiHeadAttentionForward: si_corr_pipeline=s0_corr_pipeline, s0_s1_sequence_pipeline=s0_s1_sequence_pipeline, tile_sched_params=tile_sched_params, - fused_mask=fused_mask, ) cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) @@ -1494,11 +1510,13 @@ class BlackwellFusedMultiHeadAttentionForward: and warp_idx >= self.softmax1_warp_ids[0] ): # increase register after decreasing - cute.arch.mbarrier_wait(max_reg_setting_mbar_ptr, 0) cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) self.softmax( stage=1, + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, scale_softmax_log2=scale_softmax_log2, qk_thr_mma=qk_thr_mma, tStS=tStS, @@ -1507,7 +1525,6 @@ class BlackwellFusedMultiHeadAttentionForward: si_corr_pipeline=s1_corr_pipeline, s0_s1_sequence_pipeline=s0_s1_sequence_pipeline, tile_sched_params=tile_sched_params, - fused_mask=fused_mask, ) cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) @@ -1516,19 +1533,18 @@ class BlackwellFusedMultiHeadAttentionForward: # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) - cute.arch.mbarrier_arrive(max_reg_setting_mbar_ptr) - s0_corr_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.softmax_corr_stage + s0_corr_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.softmax_corr_stage ) - s1_corr_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.softmax_corr_stage + s1_corr_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.softmax_corr_stage ) - o_corr_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.mma_corr_stage + o_corr_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_corr_stage ) - corr_epi_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.epi_stage + corr_epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.epi_stage ) cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) @@ -1566,114 +1582,138 @@ class BlackwellFusedMultiHeadAttentionForward: while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + seqlen_k = mK_kdl.shape[0] + continue_cond = False - # Ignore first signal from softmax as no correction is required - s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) - s0_corr_pipeline.consumer_release(s0_corr_consumer_state) - s0_corr_consumer_state.advance() + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) - s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - seqlen_kv_loop_steps = ( - fused_mask.get_trip_count(curr_block_coord, self.cta_tiler) - 1 - ) - for i in cutlass.range_dynamic(0, seqlen_kv_loop_steps, 1, unroll=1): - # wait for S0 + # Ignore first signal from softmax as no correction is required s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) - tTMEM_LOAD_VECrS = cute.make_fragment( - tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype - ) - # read row_wise new global max - cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) + s0_corr_pipeline.consumer_release(s0_corr_consumer_state) + s0_corr_consumer_state.advance() - scale_ = scale_softmax_log2 * ( - tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] - ) - scale = cute.arch.exp2(scale_) + s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) - self.correction_rescale(pv_thr_mma, tOtO0, scale) + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # wait for S0 + s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype + ) + # read row_wise new global max + cute.copy( + tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS + ) + + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.arch.exp2(scale_) + + mma_corr_pipeline.consumer_wait(o_corr_consumer_state) + self.correction_rescale(pv_thr_mma, tOtO0, scale) + + s1_corr_pipeline.consumer_release(s1_corr_consumer_state) + s1_corr_consumer_state.advance() + + cute.arch.fence_view_async_tmem_store() + + mma_corr_pipeline.consumer_release(o_corr_consumer_state) + o_corr_consumer_state.advance() + + s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) + + cute.copy( + tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS + ) + + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.arch.exp2(scale_) + + mma_corr_pipeline.consumer_wait(o_corr_consumer_state) + self.correction_rescale(pv_thr_mma, tOtO1, scale) + + s0_corr_pipeline.consumer_release(s0_corr_consumer_state) + s0_corr_consumer_state.advance() + + cute.arch.fence_view_async_tmem_store() + mma_corr_pipeline.consumer_release(o_corr_consumer_state) + o_corr_consumer_state.advance() + # End of seqlen_corr_loop_steps s1_corr_pipeline.consumer_release(s1_corr_consumer_state) s1_corr_consumer_state.advance() - cute.arch.fence_view_async_tmem_store() + s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) - mma_corr_pipeline.consumer_release(o_corr_consumer_state) - o_corr_consumer_state.advance() - - s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) - - cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) - - scale_ = scale_softmax_log2 * ( - tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype ) - scale = cute.arch.exp2(scale_) - - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) - self.correction_rescale(pv_thr_mma, tOtO1, scale) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() s0_corr_pipeline.consumer_release(s0_corr_consumer_state) s0_corr_consumer_state.advance() - cute.arch.fence_view_async_tmem_store() + mma_corr_pipeline.consumer_wait(o_corr_consumer_state) + corr_epi_pipeline.producer_acquire(corr_epi_producer_state) + + self.correction_epilog( + pv_thr_mma, + tOtO0, + scale_output / tTMEM_LOAD_VECrS[0], + sO[None, None, 0], + ) + mma_corr_pipeline.consumer_release(o_corr_consumer_state) o_corr_consumer_state.advance() - # End of seqlen_corr_loop_steps - s1_corr_pipeline.consumer_release(s1_corr_consumer_state) - s1_corr_consumer_state.advance() + corr_epi_pipeline.producer_commit(corr_epi_producer_state) + corr_epi_producer_state.advance() - s0_corr_pipeline.consumer_wait(s0_corr_consumer_state) + s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) + # load from V1 + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() - tTMEM_LOAD_VECrS = cute.make_fragment( - tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype - ) - cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) - cute.arch.fence_view_async_tmem_load() + s1_corr_pipeline.consumer_release(s1_corr_consumer_state) + s1_corr_consumer_state.advance() - s0_corr_pipeline.consumer_release(s0_corr_consumer_state) - s0_corr_consumer_state.advance() + mma_corr_pipeline.consumer_wait(o_corr_consumer_state) - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) - corr_epi_pipeline.producer_acquire(corr_epi_producer_state) + corr_epi_pipeline.producer_acquire(corr_epi_producer_state) + self.correction_epilog( + pv_thr_mma, + tOtO1, + scale_output / tTMEM_LOAD_VECrS[0], + sO[None, None, 1], + ) + mma_corr_pipeline.consumer_release(o_corr_consumer_state) + o_corr_consumer_state.advance() - self.correction_epilog( - pv_thr_mma, - tOtO0, - scale_output / tTMEM_LOAD_VECrS[0], - sO[None, None, 0], - ) - - mma_corr_pipeline.consumer_release(o_corr_consumer_state) - o_corr_consumer_state.advance() - - corr_epi_pipeline.producer_commit(corr_epi_producer_state) - corr_epi_producer_state.advance() - - s1_corr_pipeline.consumer_wait(s1_corr_consumer_state) - # load from V1 - cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) - cute.arch.fence_view_async_tmem_load() - - s1_corr_pipeline.consumer_release(s1_corr_consumer_state) - s1_corr_consumer_state.advance() - - mma_corr_pipeline.consumer_wait(o_corr_consumer_state) - - corr_epi_pipeline.producer_acquire(corr_epi_producer_state) - self.correction_epilog( - pv_thr_mma, - tOtO1, - scale_output / tTMEM_LOAD_VECrS[0], - sO[None, None, 1], - ) - mma_corr_pipeline.consumer_release(o_corr_consumer_state) - o_corr_consumer_state.advance() - - corr_epi_pipeline.producer_commit(corr_epi_producer_state) - corr_epi_producer_state.advance() + corr_epi_pipeline.producer_commit(corr_epi_producer_state) + corr_epi_producer_state.advance() # Advance to next tile tile_sched.advance_to_next_work() @@ -1689,15 +1729,16 @@ class BlackwellFusedMultiHeadAttentionForward: self, stage: int, need_apply_mask: bool, - row_max: cutlass.Float32, - row_sum: cutlass.Float32, - mma_si_consumer_state: utils.PipelineState, - si_corr_producer_state: utils.PipelineState, - s0_s1_sequence_state: utils.PipelineState, - mma_si_pipeline: utils.PipelineAsync, - si_corr_pipeline: utils.PipelineAsync, - s0_s1_sequence_pipeline: utils.PipelineAsync, - scale_softmax_log2: cutlass.Float32, + seqlen_k: Int32, + row_max: Float32, + row_sum: Float32, + mma_si_consumer_state: pipeline.PipelineState, + si_corr_producer_state: pipeline.PipelineState, + s0_s1_sequence_state: pipeline.PipelineState, + mma_si_pipeline: pipeline.PipelineAsync, + si_corr_pipeline: pipeline.PipelineAsync, + s0_s1_sequence_pipeline: pipeline.PipelineAsync, + scale_softmax_log2: Float32, cS: cute.Tensor, qk_thr_mma: cute.core.ThrMma, tiled_tmem_load: cute.TiledCopy, @@ -1709,13 +1750,12 @@ class BlackwellFusedMultiHeadAttentionForward: tTMEM_LOADtS: cute.Tensor, tTMEM_STORE_VECtS: cute.Tensor, tTMEM_STOREtS_x4: cute.Tensor, - fused_mask: cute.Tensor, ) -> Tuple[ - cutlass.Float32, - cutlass.Float32, - utils.PipelineState, - utils.PipelineState, - utils.PipelineState, + Float32, + Float32, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, ]: """Perform a single step of the softmax computation on a block of attention scores. @@ -1741,19 +1781,19 @@ class BlackwellFusedMultiHeadAttentionForward: :param row_sum: Current sum value for the row :type row_sum: cute.core.Tensor :param mma_si_consumer_state: Pipeline state for MMA consumer operations - :type mma_si_consumer_state: utils.PipelineState + :type mma_si_consumer_state: pipeline.PipelineState :param si_corr_producer_state: Pipeline state for correction producer operations - :type si_corr_producer_state: utils.PipelineState + :type si_corr_producer_state: pipeline.PipelineState :param s0_s1_sequence_state: Pipeline state for sequence synchronization - :type s0_s1_sequence_state: utils.PipelineState + :type s0_s1_sequence_state: pipeline.PipelineState :param mma_si_pipeline: Pipeline for MMA operations - :type mma_si_pipeline: utils.PipelineAsync + :type mma_si_pipeline: pipeline.PipelineAsync :param si_corr_pipeline: Pipeline for correction operations - :type si_corr_pipeline: utils.PipelineAsync + :type si_corr_pipeline: pipeline.PipelineAsync :param s0_s1_sequence_pipeline: Pipeline for sequence synchronization - :type s0_s1_sequence_pipeline: utils.PipelineAsync + :type s0_s1_sequence_pipeline: pipeline.PipelineAsync :param scale_softmax_log2: Log2 scale factor for softmax computation - :type scale_softmax_log2: cutlass.Float32 + :type scale_softmax_log2: Float32 :param cS: Current slice of attention matrix :type cS: cute.Tensor :param qk_thr_mma: Thread MMA operation @@ -1776,14 +1816,10 @@ class BlackwellFusedMultiHeadAttentionForward: :type tTMEM_STORE_VECtS: cute.Tensor :param tTMEM_STOREtS_x4: Tensor for storing processed data :type tTMEM_STOREtS_x4: cute.Tensor - :param fused_mask: Mask configuration for attention masking - :type fused_mask: cute.Tensor :return: Updated state values (row_max, row_sum, and pipeline states) :rtype: tuple """ - tilePlikeFP32 = ( - self.qk_mma_tiler[1] // cutlass.Float32.width * self.o_dtype.width - ) + tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width tScS = qk_thr_mma.partition_C(cS) tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) @@ -1802,7 +1838,7 @@ class BlackwellFusedMultiHeadAttentionForward: cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) if need_apply_mask: - fused_mask.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS) + self.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, seqlen_k) old_row_max = row_max row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) @@ -1884,7 +1920,7 @@ class BlackwellFusedMultiHeadAttentionForward: frg_tile = cute.size(tTMEM_LOADrS) // reduction_unroll tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) - for j in range(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): + for j in cutlass.range_constexpr(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): local_row_sum_0 = cute.arch.add_packed_f32x2( local_row_sum_0, (tTMEM_LOADrS_frg[j, 0], tTMEM_LOADrS_frg[j + 1, 0]) ) @@ -1916,15 +1952,17 @@ class BlackwellFusedMultiHeadAttentionForward: def softmax( self, stage: int, - scale_softmax_log2: cutlass.Float32, + seqlen_k: Int32, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, qk_thr_mma: cute.core.ThrMma, tStS: cute.Tensor, tStSi: cute.Tensor, - mma_si_pipeline: utils.PipelineAsync, - si_corr_pipeline: utils.PipelineAsync, - s0_s1_sequence_pipeline: utils.PipelineAsync, + mma_si_pipeline: pipeline.PipelineAsync, + si_corr_pipeline: pipeline.PipelineAsync, + s0_s1_sequence_pipeline: pipeline.PipelineAsync, tile_sched_params: FmhaStaticTileSchedulerParams, - fused_mask: FusedMask, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1940,7 +1978,7 @@ class BlackwellFusedMultiHeadAttentionForward: :param stage: Processing stage (0 for first half, 1 for second half of attention matrix) :type stage: int :param scale_softmax_log2: Log2 scale factor for softmax operation - :type scale_softmax_log2: cutlass.Float32 + :type scale_softmax_log2: Float32 :param qk_thr_mma: Thread MMA operation for QK matrix multiplication :type qk_thr_mma: cute.core.ThrMma :param tStS: Shared tensor for softmax input/output @@ -1948,15 +1986,13 @@ class BlackwellFusedMultiHeadAttentionForward: :param tStSi: Input tensor containing attention scores :type tStSi: cute.Tensor :param mma_si_pipeline: Pipeline for synchronizing with MMA operations - :type mma_si_pipeline: utils.PipelineAsync + :type mma_si_pipeline: pipeline.PipelineAsync :param si_corr_pipeline: Pipeline for synchronizing with correction operations - :type si_corr_pipeline: utils.PipelineAsync + :type si_corr_pipeline: pipeline.PipelineAsync :param s0_s1_sequence_pipeline: Pipeline for synchronizing between stage 0 and 1 - :type s0_s1_sequence_pipeline: utils.PipelineAsync + :type s0_s1_sequence_pipeline: pipeline.PipelineAsync :param tile_sched_params: Parameters for tile scheduling :type tile_sched_params: FmhaStaticTileSchedulerParams - :param fused_mask: Mask configuration for attention masking - :type fused_mask: FusedMask """ tidx, _, _ = cute.arch.thread_idx() thread_idx = tidx % ( @@ -2023,17 +2059,17 @@ class BlackwellFusedMultiHeadAttentionForward: thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) - mma_si_consumer_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.mma_softmax_stage + mma_si_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_softmax_stage ) - si_corr_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.softmax_corr_stage + si_corr_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.softmax_corr_stage ) - s0_s1_sequence_state = utils.make_pipeline_state( + s0_s1_sequence_state = pipeline.make_pipeline_state( ( - utils.PipelineUserType.Producer + pipeline.PipelineUserType.Producer if stage == 0 - else utils.PipelineUserType.Consumer + else pipeline.PipelineUserType.Consumer ), 1, ) @@ -2045,115 +2081,138 @@ class BlackwellFusedMultiHeadAttentionForward: while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx - logical_offset = ( - curr_block_coord[0] * self.cta_tiler[0] + stage * self.qk_mma_tiler[0], - 0, - ) + batch_coord = curr_block_coord[2][1] + seqlen_k_ = seqlen_k + continue_cond = False - cS = cute.domain_offset(logical_offset, cS_base) - - si_corr_pipeline.producer_acquire(si_corr_producer_state) - unmask_count = fused_mask.get_unmasked_trip_count( - curr_block_coord, - self.cta_tiler, - ) - - row_max = -cutlass.Float32.inf - row_sum = 0.0 - - for i in cutlass.range_dynamic(0, unmask_count, 1, unroll=1): - cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) - ( - row_max, - row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, - ) = self.softmax_step( - stage, - False, - row_max, - row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, - mma_si_pipeline, - si_corr_pipeline, - s0_s1_sequence_pipeline, - scale_softmax_log2, - cS_iter, - qk_thr_mma, - tiled_tmem_load, - tiled_tmem_store, - tiled_tmem_store_vec, - thr_tmem_load, - thr_tmem_store, - thr_tmem_store_vec, - tTMEM_LOADtS, - tTMEM_STORE_VECtS, - tTMEM_STOREtS_x4, - fused_mask, + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) ) - mask_count = fused_mask.get_masked_trip_count( - curr_block_coord, - self.cta_tiler, - ) + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k_ = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - for i in cutlass.range_dynamic( - unmask_count, unmask_count + mask_count, 1, unroll=1 - ): - cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) - ( - row_max, - row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, - ) = self.softmax_step( - stage, - True, - row_max, - row_sum, - mma_si_consumer_state, - si_corr_producer_state, - s0_s1_sequence_state, - mma_si_pipeline, - si_corr_pipeline, - s0_s1_sequence_pipeline, - scale_softmax_log2, - cS_iter, - qk_thr_mma, - tiled_tmem_load, - tiled_tmem_store, - tiled_tmem_store_vec, - thr_tmem_load, - thr_tmem_store, - thr_tmem_store_vec, - tTMEM_LOADtS, - tTMEM_STORE_VECtS, - tTMEM_STOREtS_x4, - fused_mask, + logical_offset = ( + curr_block_coord[0] * self.cta_tiler[0] + + stage * self.qk_mma_tiler[0], + 0, ) - mma_si_pipeline.consumer_wait(mma_si_consumer_state) + cS = cute.domain_offset(logical_offset, cS_base) + si_corr_pipeline.producer_acquire(si_corr_producer_state) - tTMEM_STORE_VECrS = cute.make_fragment( - tTMEM_STORE_VECcS.shape, self.qk_acc_dtype - ) - tTMEM_STORE_VECrS[0] = row_sum - tTMEM_STORE_VECrS[1] = row_max - cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) - cute.arch.fence_view_async_tmem_store() + unmask_count = self.get_unmasked_trip_count( + curr_block_coord, + self.cta_tiler, + seqlen_k_, + ) - si_corr_pipeline.producer_commit(si_corr_producer_state) - si_corr_producer_state.advance() + row_max = -Float32.inf + row_sum = 0.0 - si_corr_pipeline.producer_acquire(si_corr_producer_state) + for i in cutlass.range(0, unmask_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + ( + row_max, + row_sum, + mma_si_consumer_state, + si_corr_producer_state, + s0_s1_sequence_state, + ) = self.softmax_step( + stage, + False, + seqlen_k_, + row_max, + row_sum, + mma_si_consumer_state, + si_corr_producer_state, + s0_s1_sequence_state, + mma_si_pipeline, + si_corr_pipeline, + s0_s1_sequence_pipeline, + scale_softmax_log2, + cS_iter, + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) - # Empty step to sync against pipe s - mma_si_pipeline.consumer_release(mma_si_consumer_state) - mma_si_consumer_state.advance() + mask_count = self.get_masked_trip_count( + curr_block_coord, + self.cta_tiler, + seqlen_k_, + ) + + for i in cutlass.range( + unmask_count, unmask_count + mask_count, 1, unroll=1 + ): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + ( + row_max, + row_sum, + mma_si_consumer_state, + si_corr_producer_state, + s0_s1_sequence_state, + ) = self.softmax_step( + stage, + True, + seqlen_k_, + row_max, + row_sum, + mma_si_consumer_state, + si_corr_producer_state, + s0_s1_sequence_state, + mma_si_pipeline, + si_corr_pipeline, + s0_s1_sequence_pipeline, + scale_softmax_log2, + cS_iter, + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) + + mma_si_pipeline.consumer_wait(mma_si_consumer_state) + + tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECcS.shape, self.qk_acc_dtype + ) + tTMEM_STORE_VECrS[0] = row_sum + tTMEM_STORE_VECrS[1] = row_max + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + + si_corr_pipeline.producer_commit(si_corr_producer_state) + si_corr_producer_state.advance() + + si_corr_pipeline.producer_acquire(si_corr_producer_state) + + # Empty step to sync against pipe s + mma_si_pipeline.consumer_release(mma_si_consumer_state) + mma_si_consumer_state.advance() # Advance to next tile tile_sched.advance_to_next_work() @@ -2165,7 +2224,7 @@ class BlackwellFusedMultiHeadAttentionForward: self, thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, - scale: cutlass.Float32, + scale: Float32, ): """Rescale intermediate attention results based on softmax normalization factor. @@ -2184,7 +2243,7 @@ class BlackwellFusedMultiHeadAttentionForward: :param tOtO: Tensor representing partial attention output to be rescaled :type tOtO: cute.Tensor :param scale: Scaling factor to apply to the partial results - :type scale: cutlass.Float32 + :type scale: Float32 """ pv_tiled_mma_shape = ( self.pv_mma_tiler[0], @@ -2254,7 +2313,7 @@ class BlackwellFusedMultiHeadAttentionForward: self, thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, - scale: cutlass.Float32, + scale: Float32, sO: cute.Tensor, ): """Apply final scaling and transformation to attention output before writing to global memory. @@ -2275,7 +2334,7 @@ class BlackwellFusedMultiHeadAttentionForward: :param tOtO: Tensor containing accumulated attention output :type tOtO: cute.Tensor :param scale: Final scaling factor to apply to the output - :type scale: cutlass.Float32 + :type scale: Float32 :param sO: Shared memory tensor for the final output :type sO: cute.Tensor """ @@ -2347,129 +2406,206 @@ class BlackwellFusedMultiHeadAttentionForward: space=cute.arch.SharedSpace.shared_cta, ) + def get_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if ( + self.mask_type == MaskType.NO_MASK + or self.mask_type == MaskType.RESIDUAL_MASK + ): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + elif self.mask_type == MaskType.CAUSAL_MASK: + max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) + max_blocks_q = cute.ceil_div( + (blk_coord[0] + 1) * tile_shape[0], tile_shape[1] + ) + result = cutlass.min(max_blocks_k, max_blocks_q) + return result + + @cute.jit + def get_masked_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if self.mask_type == MaskType.NO_MASK: + result = 0 + elif self.mask_type == MaskType.RESIDUAL_MASK: + if seqlen_k % tile_shape[1] != 0: + result = 1 + else: + result = 0 + elif self.mask_type == MaskType.CAUSAL_MASK: + result = cute.ceil_div(tile_shape[0], tile_shape[1]) + return result + + @cute.jit + def get_unmasked_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if self.mask_type == MaskType.NO_MASK: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + elif self.mask_type == MaskType.RESIDUAL_MASK: + if seqlen_k % tile_shape[1] != 0: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) - 1 + else: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + elif self.mask_type == MaskType.CAUSAL_MASK: + result = self.get_trip_count( + blk_coord, tile_shape, seqlen_k + ) - self.get_masked_trip_count(blk_coord, tile_shape, seqlen_k) + return result + + @cute.jit + def apply_mask( + self, + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_k: Int32, + ): + if self.mask_type == MaskType.RESIDUAL_MASK: + for i in range(cute.size(acc_qk)): + pos = index_qk[i] + if pos[1] >= seqlen_k: + acc_qk[i] = -Float32.inf + elif self.mask_type == MaskType.CAUSAL_MASK: + for i in range(cute.size(acc_qk)): + pos = index_qk[i] + if pos[0] < pos[1] or pos[1] >= seqlen_k: + acc_qk[i] = -Float32.inf + def make_and_init_load_q_pipeline(self, load_q_mbar_ptr): - load_q_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, len([self.load_warp_id]) + load_q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) ) - load_q_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, len([self.mma_warp_id]) + load_q_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) ) - return utils.PipelineTmaUmma.create( - barrier_storage=load_q_mbar_ptr, + return pipeline.PipelineTmaUmma.create( num_stages=self.q_stage, producer_group=load_q_producer_group, consumer_group=load_q_consumer_group, tx_count=self.tma_copy_q_bytes, + barrier_storage=load_q_mbar_ptr, ) def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): - load_kv_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, len([self.load_warp_id]) + load_kv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) ) - load_kv_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, len([self.mma_warp_id]) + load_kv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) ) - return utils.PipelineTmaUmma.create( - barrier_storage=load_kv_mbar_ptr, + return pipeline.PipelineTmaUmma.create( num_stages=self.kv_stage, producer_group=load_kv_producer_group, consumer_group=load_kv_consumer_group, tx_count=self.tma_copy_kv_bytes, + barrier_storage=load_kv_mbar_ptr, ) def make_and_init_mma_si_pipeline(self, mma_si_mbar_ptr): - mma_si_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, len([self.mma_warp_id]) + mma_si_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) ) - mma_si_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, + mma_si_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len(self.softmax0_warp_ids), self.threads_per_warp * len(self.softmax0_warp_ids), ) - return utils.PipelineUmmaAsync.create( - barrier_storage=mma_si_mbar_ptr, + return pipeline.PipelineUmmaAsync.create( num_stages=self.mma_softmax_stage, producer_group=mma_si_producer_group, consumer_group=mma_si_consumer_group, + barrier_storage=mma_si_mbar_ptr, ) def make_and_init_si_corr_pipeline(self, si_corr_mbar_ptr): - si_corr_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, + si_corr_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len(self.softmax0_warp_ids), self.threads_per_warp * len(self.softmax0_warp_ids), ) - si_corr_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, + si_corr_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len(self.correction_warp_ids), self.threads_per_warp * len(self.correction_warp_ids), ) - return utils.PipelineAsync.create( - barrier_storage=si_corr_mbar_ptr, + return pipeline.PipelineAsync.create( num_stages=self.softmax_corr_stage, producer_group=si_corr_producer_group, consumer_group=si_corr_consumer_group, + barrier_storage=si_corr_mbar_ptr, ) def make_and_init_corr_epi_pipeline(self, corr_epi_mbar_ptr): - corr_epi_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, + corr_epi_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len(self.correction_warp_ids), self.threads_per_warp * len(self.correction_warp_ids), ) - corr_epi_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, + corr_epi_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len([self.epilogue_warp_id]), self.threads_per_warp * len([self.epilogue_warp_id]), ) - return utils.PipelineAsync.create( - barrier_storage=corr_epi_mbar_ptr, + return pipeline.PipelineAsync.create( num_stages=self.epi_stage, producer_group=corr_epi_producer_group, consumer_group=corr_epi_consumer_group, + barrier_storage=corr_epi_mbar_ptr, ) def make_and_init_mma_corr_pipeline(self, mma_corr_mbar_ptr): - mma_corr_producer_group = utils.CooperativeGroup( - utils.Agent.Thread, len([self.mma_warp_id]) + mma_corr_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) ) - mma_corr_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, + mma_corr_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len(self.correction_warp_ids), self.threads_per_warp * len(self.correction_warp_ids), ) - return utils.PipelineUmmaAsync.create( - barrier_storage=mma_corr_mbar_ptr, + return pipeline.PipelineUmmaAsync.create( num_stages=self.mma_corr_stage, producer_group=mma_corr_producer_group, consumer_group=mma_corr_consumer_group, + barrier_storage=mma_corr_mbar_ptr, ) def make_and_init_si_sequence_pipeline(self, si_sequence_mbar_ptr): - s0_sequence_group = utils.CooperativeGroup( - utils.Agent.Thread, + s0_sequence_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len(self.softmax0_warp_ids), self.threads_per_warp * len(self.softmax0_warp_ids), ) - s1_sequence_group = utils.CooperativeGroup( - utils.Agent.Thread, + s1_sequence_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_warp * len(self.softmax1_warp_ids), self.threads_per_warp * len(self.softmax1_warp_ids), ) - return utils.PipelineAsync.create( - barrier_storage=si_sequence_mbar_ptr, + return pipeline.PipelineAsync.create( num_stages=1, producer_group=s0_sequence_group, consumer_group=s1_sequence_group, + barrier_storage=si_sequence_mbar_ptr, ) @staticmethod def _compute_grid( - o: cute.Tensor, + o_shape: cute.Shape, cta_tiler: Tuple[int, int, int], is_persistent: bool, ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: - o_shape = o.shape tile_sched_params = create_fmha_static_tile_scheduler_params( is_persistent, ( @@ -2484,8 +2620,8 @@ class BlackwellFusedMultiHeadAttentionForward: def run_fmha_and_verify( - q_shape: Tuple[int, int, int, int], - k_shape: Tuple[int, int, int, int], + q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], + k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], in_dtype: Type[cutlass.Numeric], out_dtype: Type[cutlass.Numeric], qk_acc_dtype: Type[cutlass.Numeric], @@ -2515,11 +2651,13 @@ def run_fmha_and_verify( for maximum throughput. :param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length, - H=number of heads, D=head dimension - :type q_shape: Tuple[int, int, int, int] + H=number of heads, D=head dimension. + If S_q is a tuple, it is the variable sequence length. + :type q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] :param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length, - H_k=number of key heads (H must be divisible by H_k), D=head dimension - :type k_shape: Tuple[int, int, int, int] + H_k=number of key heads (H must be divisible by H_k), D=head dimension. + If S_k is a tuple, it is the variable sequence length. + :type k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] :param in_dtype: Input data type for query, key and value tensors :type in_dtype: Type[cutlass.Numeric] :param out_dtype: Output data type for attention output @@ -2575,7 +2713,7 @@ def run_fmha_and_verify( print(f" tolerance: {tolerance}") # Unpack parameters - b, s_q, h, d = q_shape + b, s_q, h_q, d = q_shape b_, s_k, h_k, d_ = k_shape if b != b_: @@ -2587,8 +2725,13 @@ def run_fmha_and_verify( if d not in {32, 64, 128}: raise ValueError("head dimension must be 32, 64, or 128") - if h % h_k != 0: - raise ValueError("h must be divisible by h_k") + if h_q % h_k != 0: + raise ValueError("h_q must be divisible by h_k") + + if isinstance(s_q, tuple) and len(s_q) != b: + raise ValueError("variable_seqlen s_q must have the length of batch size") + if isinstance(s_k, tuple) and len(s_k) != b: + raise ValueError("variable_seqlen s_k must have the length of batch size") if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: raise ValueError("in_dtype must be Float8E4M3FN or Float16") @@ -2596,16 +2739,16 @@ def run_fmha_and_verify( if out_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: raise ValueError("out_dtype must be Float8E4M3FN or Float16") - if qk_acc_dtype not in {cutlass.Float32}: + if qk_acc_dtype not in {Float32}: raise ValueError("qk_acc_dtype must be Float32") - if pv_acc_dtype not in {cutlass.Float32}: + if pv_acc_dtype not in {Float32}: raise ValueError("pv_acc_dtype must be Float32") if iterations < 1: raise ValueError("iterations must be at least 1") - h_r = h // h_k + h_r = h_q // h_k # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) if not torch.cuda.is_available(): @@ -2613,56 +2756,127 @@ def run_fmha_and_verify( torch.manual_seed(1111) - def create_and_permute_tensor(b, s, h_r, h_k, d, dtype, is_dynamic_layout=True): - # (b, s, h_r, h_k, d) -> (s, d, h_r, h_k, b) - shape = (b, s, h_r, h_k, d) - permute_order = (1, 4, 2, 3, 0) - is_fp8 = dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + def create_cumulative_sequence_lengths(s): + s_cumsum = [0] + for i in range(len(s)): + s_cumsum.append(s_cumsum[-1] + s[i]) - # torch does not support fp8 type - torch_dtype = cutlass.torch.dtype(dtype) if not is_fp8 else torch.uint8 - - # Create dtype torch tensor (cpu) - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch_dtype, - permute_order=permute_order, - init_type=cutlass.torch.TensorInitType.RANDOM, - init_config=cutlass.torch.RandomInitConfig( - min_val=0 if is_fp8 else -2, max_val=2 - ), + s_cumsum_cute_tensor, s_cumsum_torch_tensor = cutlass_torch.cute_tensor_like( + torch.tensor(s_cumsum, dtype=torch.int32), + Int32, + is_dynamic_layout=True, + assumed_align=16, ) - # Create dtype torch tensor (gpu) - torch_tensor_gpu = torch_tensor_cpu.cuda() + + return s_cumsum_cute_tensor, s_cumsum_torch_tensor + + cum_seqlen_q, cum_seqlen_q_torch = ( + create_cumulative_sequence_lengths(s_q) + if isinstance(s_q, tuple) + else (None, None) + ) + cum_seqlen_k, cum_seqlen_k_torch = ( + create_cumulative_sequence_lengths(s_k) + if isinstance(s_k, tuple) + else (None, None) + ) + + def create_and_pad_tensor( + shape, padding, dtype, s_cumsum=None, is_dynamic_layout=True + ): + # (b, s, h, d) + shape_ = tuple(map(lambda x, y: x + y, shape, padding)) + if s_cumsum is not None: + if shape_[0] != 1 or padding[0] != 0: + raise ValueError("Invalid tensor creation for variable sequence length") + # (s_total + padding, h, d) + shape_ = shape_[1:] + padding = padding[1:] # Create f32 torch tensor (cpu) - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - # Create dtype cute tensor (gpu) - cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) - cute_tensor.element_type = dtype - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=1) - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, - cute_tensor, + f32_torch_tensor_full = cutlass_torch.create_and_permute_torch_tensor( + shape_, + torch.float32, + permute_order=None, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=-2 if dtype.is_float or dtype.signed else 0, max_val=2 + ), + ) + # Create dtype cute & torch tensor (gpu) + _, torch_tensor_full = cutlass_torch.cute_tensor_like( + f32_torch_tensor_full, dtype, - is_dynamic_layout=is_dynamic_layout, + is_dynamic_layout, + assumed_align=16, ) - return f32_torch_tensor, cute_tensor, torch_tensor_gpu + # Offset the tensor + slices = tuple(slice(s, e) for s, e in zip(padding, shape_)) + torch_tensor = torch_tensor_full[slices].detach() + f32_torch_tensor = f32_torch_tensor_full[slices].detach() + torch_tensor._keep_alive = torch_tensor_full + f32_torch_tensor._keep_alive = f32_torch_tensor_full - q_ref, q_tensor, q_torch = create_and_permute_tensor( - b, s_q, h_r, h_k, d, in_dtype, is_dynamic_layout=True + # Create dtype cute tensor with offset (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + + # From ragged to jagged + if s_cumsum is not None: + torch_tensor = torch.nested.nested_tensor_from_jagged( + values=torch_tensor, offsets=s_cumsum + ) + f32_torch_tensor = torch.nested.nested_tensor_from_jagged( + values=f32_torch_tensor, offsets=s_cumsum.cpu() + ) + + return ( + f32_torch_tensor, + cute_tensor, + torch_tensor, + ) + + qo_shape = (b, s_q, h_r * h_k, d) + kv_shape = (b, s_k, h_k, d) + qo_padding = (0, 0, 0, 0, 0) + kv_padding = (0, 0, 0, 0, 0) + + if isinstance(s_q, tuple): + qo_shape = (1, sum(s_q), h_r * h_k, d) + qo_padding = (0, max(s_q), 0, 0, 0) + + if isinstance(s_k, tuple): + kv_shape = (1, sum(s_k), h_k, d) + kv_padding = (0, max(s_k), 0, 0, 0) + + q_ref, q_tensor, q_torch = create_and_pad_tensor( + qo_shape, + qo_padding, + in_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, ) - k_ref, k_tensor, k_torch = create_and_permute_tensor( - b, s_k, 1, h_k, d, in_dtype, is_dynamic_layout=True + k_ref, k_tensor, k_torch = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, ) - v_ref, v_tensor, v_torch = create_and_permute_tensor( - b, s_k, 1, h_k, d, in_dtype, is_dynamic_layout=True + v_ref, v_tensor, v_torch = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, ) - o_ref, o_tensor, o_torch = create_and_permute_tensor( - b, s_q, h_r, h_k, d, out_dtype, is_dynamic_layout=True + _, o_tensor, o_torch = create_and_pad_tensor( + qo_shape, + qo_padding, + out_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, ) mma_tiler = (*mma_tiler_mn, d) @@ -2671,8 +2885,13 @@ def run_fmha_and_verify( if has_casual_mask: mask_type = MaskType.CAUSAL_MASK else: - if s_k % mma_tiler_mn[1] != 0: - mask_type = MaskType.RESIDUAL_MASK + if isinstance(s_k, tuple): + for i in range(len(s_k)): + if s_k[i] % mma_tiler_mn[1] != 0: + mask_type = MaskType.RESIDUAL_MASK + else: + if s_k % mma_tiler_mn[1] != 0: + mask_type = MaskType.RESIDUAL_MASK fmha = BlackwellFusedMultiHeadAttentionForward( qk_acc_dtype, @@ -2682,13 +2901,11 @@ def run_fmha_and_verify( mask_type, ) - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - # Get the raw stream pointer as a CUstream - current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Initialize Stream + current_stream = cutlass_torch.default_stream() if scale_softmax == 0.0: # default to 1/sqrt(d) - scale_softmax = 1.0 / math.sqrt(q_shape[1]) + scale_softmax = 1.0 / math.sqrt(d) log2_e = math.log2( math.exp(1.0) ) # gpu uses exp2 for perf concerns, we need an extra factor 'log2_e' here @@ -2697,15 +2914,27 @@ def run_fmha_and_verify( scale_softmax_log2 = scale_softmax * log2_e scale_output = scale_v * inv_scale_o + problem_size = ( + b, + max(s_q) if isinstance(s_q, tuple) else s_q, + max(s_k) if isinstance(s_k, tuple) else s_k, + h_q, + h_k, + d, + ) + print("Compiling kernel with cute.compile ...") start_time = time.time() # compile fmha kernel compiled_fmha = cute.compile( fmha, - q_tensor, - k_tensor, - v_tensor, - o_tensor, + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, scale_softmax_log2, scale_output, current_stream, @@ -2716,10 +2945,13 @@ def run_fmha_and_verify( # Warmup for _ in range(warmup_iterations): compiled_fmha( - q_tensor, - k_tensor, - v_tensor, - o_tensor, + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, scale_softmax_log2, scale_output, current_stream, @@ -2728,10 +2960,13 @@ def run_fmha_and_verify( # Execute kernel for _ in range(iterations): compiled_fmha( - q_tensor, - k_tensor, - v_tensor, - o_tensor, + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, scale_softmax_log2, scale_output, current_stream, @@ -2742,84 +2977,127 @@ def run_fmha_and_verify( def run_torch_fmha( q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False ): - s_q, d, h_r, h_k, b = q.shape - s_k = k.shape[0] + h_q = q.shape[2] + h_k = k.shape[2] - # broadcast k and v to have the same shape as q - k = k.expand(s_k, d, h_r, h_k, b) - v = v.expand(s_k, d, h_r, h_k, b) + if not h_q == h_k: + repeat_factor = h_q // h_k + # nested tensor can not be broadcasted directly + if k.is_nested: + k_offsets = k.offsets() + v_offsets = v.offsets() + k_values = k.values().repeat(1, repeat_factor, 1) + v_values = v.values().repeat(1, repeat_factor, 1) - q_tmp = q.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_q, d) - k_tmp = k.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_k, d) - v_tmp = v.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_k, d) + k = torch.nested.nested_tensor_from_jagged( + values=k_values, offsets=k_offsets + ) + v = torch.nested.nested_tensor_from_jagged( + values=v_values, offsets=v_offsets + ) + else: + k = k.repeat(1, 1, repeat_factor, 1) + v = v.repeat(1, 1, repeat_factor, 1) - ref = F.scaled_dot_product_attention( - q_tmp, - k_tmp, - v_tmp, - attn_mask=None, - dropout_p=0.0, - scale=scale_softmax, - is_causal=has_casual_mask, + # as we initialize q, k, v with shape (b, s, h, d) and SDPA of torch needs them to be (b, h, s, d) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # For the situation that torch has not supported, we need to handle it manually + situation1 = has_casual_mask and (q.is_nested or k.is_nested) + situation2 = (q.is_nested and not k.is_nested) or ( + not q.is_nested and k.is_nested ) - ref = ref.view(b, h_r, h_k, s_q, d).permute(3, 4, 1, 2, 0) * scale_output + if situation1 or situation2: + # Once torch supports the situation, we can remove this fallback + batch_size = q.size(0) + ref_list = [] + for batch_idx in range(batch_size): + q_i = q[batch_idx] + k_i = k[batch_idx] + v_i = v[batch_idx] + ref_i = F.scaled_dot_product_attention( + q_i, + k_i, + v_i, + attn_mask=None, + dropout_p=0.0, + scale=scale_softmax, + is_causal=has_casual_mask, + ) + ref_list.append(ref_i) + if q.is_nested: + ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged) + else: + ref = torch.stack(ref_list) + else: + ref = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + scale=scale_softmax, + is_causal=has_casual_mask, + ) + ref = ref.transpose(1, 2) * scale_output return ref if not skip_ref_check: print("Verifying results...") - ref = run_torch_fmha( + o_ref = run_torch_fmha( q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask ) - # Copy gpu result back - gpu_o = o_torch.cpu() + if o_ref.is_nested: + o_ref = o_ref.values() - # convert ref to out_type - if out_dtype == cutlass.Float16: - ref_o = ref.to(cutlass.torch.dtype(out_dtype)) - elif out_dtype in {cutlass.Float8E4M3FN, cutlass.Float8E5M2}: - # convert ref : f32 -> fp8 -> f32 - permute_order_0 = (4, 0, 2, 3, 1) - permute_order_1 = (1, 4, 2, 3, 0) + if o_torch.is_nested: + o_torch = o_torch.values() - shape = (b, s_q, h_r, h_k, d) + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o_tensor, o_fp32) + o_result = o_fp32_torch.cpu() - f8_torch_tensor = cutlass.torch.create_and_permute_torch_tensor( - shape, - torch.uint8, - permute_order=permute_order_1, - init_type=cutlass.torch.TensorInitType.SKIP, - ).cuda() - - # Create dtype tensor (gpu) - ref_o_tensor = from_dlpack( - f8_torch_tensor, assumed_align=16 - ).mark_layout_dynamic(leading_dim=1) - ref_o_tensor.element_type = out_dtype - ref_o_tensor = cutlass.torch.convert_cute_tensor( - # ref for torch tensor is contiguous in shape (b, h_r, h_k, s_q, d), but shape is (s, d, h_r, h_k, b) - # need to make it contiguous first then permute - ref.permute(permute_order_0).contiguous().permute(permute_order_1).cuda(), - ref_o_tensor, + if out_dtype.is_float and out_dtype.width <= 8: + ref_narrow_precision, _ = cutlass_torch.cute_tensor_like( + torch.empty(*o_ref.shape, dtype=torch.uint8), out_dtype, is_dynamic_layout=True, + assumed_align=16, ) - ref_o = f8_torch_tensor.cpu() + ref_o_f32, ref_o_f32_torch = cutlass_torch.cute_tensor_like( + o_ref, + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) - # uint8 check; the minimum difference is 1 - tolerance = 2 - else: - pass + # convert ref : f32 -> fp4/fp8 -> f32 + cute.testing.convert(ref_o_f32, ref_narrow_precision) + cute.testing.convert(ref_narrow_precision, ref_o_f32) + + o_ref = ref_o_f32_torch.cpu() + + # override tolerance + tolerance = 0.13 # Assert close results - torch.testing.assert_close(gpu_o, ref_o, atol=tolerance, rtol=1e-05) + torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05) print("Results verified successfully!") if __name__ == "__main__": - def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + def parse_comma_separated_ints(s: str): try: return tuple(int(x.strip()) for x in s.split(",")) except ValueError: @@ -2827,6 +3105,42 @@ if __name__ == "__main__": "Invalid format. Expected comma-separated integers." ) + def parse_nested_comma_separated_ints(s: str): + try: + s = s.strip() + if "(" not in s: + return tuple(int(x.strip()) for x in s.split(",")) + + start = s.find("(") + end = s.find(")") + if start == -1 or end == -1: + raise ValueError("Mismatched parentheses") + + before = s[:start].strip().rstrip(",") + middle = s[start + 1 : end].strip() + after = s[end + 1 :].strip().lstrip(",") + + result = [] + if before: + result.extend(int(x.strip()) for x in before.split(",")) + + if middle: + nested_tuple = tuple(int(x.strip()) for x in middle.split(",")) + result.append(nested_tuple) + + if after: + result.extend(int(x.strip()) for x in after.split(",")) + + return tuple(result) + + except ValueError as e: + if str(e) == "Mismatched parentheses": + raise argparse.ArgumentTypeError("Mismatched parentheses in input") + else: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers with optional parentheses for nested tuple." + ) + parser = argparse.ArgumentParser(description="Example of FMHA on Blackwell.") parser.add_argument( @@ -2846,14 +3160,14 @@ if __name__ == "__main__": parser.add_argument( "--qk_acc_dtype", type=cutlass.dtype, - default=cutlass.Float32, + default=Float32, help="QK accumulator data type", ) parser.add_argument( "--pv_acc_dtype", type=cutlass.dtype, - default=cutlass.Float32, + default=Float32, help="PV accumulator data type", ) @@ -2878,14 +3192,14 @@ if __name__ == "__main__": parser.add_argument( "--q_shape", - type=parse_comma_separated_ints, + type=parse_nested_comma_separated_ints, default=(1, 256, 8, 128), help="Shape of Q (B, S_q, H, D)", ) parser.add_argument( "--k_shape", - type=parse_comma_separated_ints, + type=parse_nested_comma_separated_ints, default=(1, 256, 8, 128), help="Shape of K (B, S_k, H_k, D)", ) @@ -2960,6 +3274,11 @@ if __name__ == "__main__": if len(args.mma_tiler_mn) != 2: parser.error("--mma_tiler_mn must contain exactly 2 values") + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + run_fmha_and_verify( args.q_shape, args.k_shape, diff --git a/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_gemm.py index d2e6f9ab..3b67bf3b 100644 --- a/examples/python/CuTeDSL/blackwell/grouped_gemm.py +++ b/examples/python/CuTeDSL/blackwell/grouped_gemm.py @@ -40,7 +40,6 @@ import cutlass.utils as utils from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.torch as cutlass_torch -from cutlass.cute.runtime import from_dlpack """ A grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL @@ -89,7 +88,6 @@ there are also the following constrains: class GroupedGemmKernel: - def __init__( self, acc_dtype: type[cutlass.Numeric], @@ -159,7 +157,7 @@ class GroupedGemmKernel: self.tmem_ptr_sync_bar_id = 2 # Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion self.tensormap_ab_init_bar_id = 4 - self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] self.num_tma_load_bytes = 0 def _setup_attributes(self): @@ -217,18 +215,20 @@ class GroupedGemmKernel: ) # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory - self.num_acc_stage, self.num_ab_stage, self.num_epi_stage = ( - self._compute_stages( - tiled_mma, - self.mma_tiler, - self.a_dtype, - self.b_dtype, - self.epi_tile, - self.c_dtype, - self.c_layout, - self.num_smem_capacity, - self.occupancy, - ) + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_epi_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, ) self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( @@ -355,9 +355,11 @@ class GroupedGemmKernel: atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A - a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A( + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, initial_a, a_smem_layout, @@ -367,9 +369,11 @@ class GroupedGemmKernel: ) # Setup TMA load for B - b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B( + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, initial_b, b_smem_layout, @@ -389,7 +393,7 @@ class GroupedGemmKernel: cute.make_identity_layout(initial_c.shape), self.epi_tile ) epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0)) - tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom( + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), initial_c, epi_smem_layout, @@ -403,9 +407,7 @@ class GroupedGemmKernel: self.buffer_align_bytes = 1024 self.size_tensormap_in_i64 = ( 0 - if cutlass.const_expr( - self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM - ) + if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM else GroupedGemmKernel.num_tensormaps * GroupedGemmKernel.bytes_per_tensormap // 8 @@ -564,16 +566,16 @@ class GroupedGemmKernel: for k_stage in range(self.num_ab_stage): num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 with cute.arch.elect_one(): - cute.arch.mbarrier_init_arrive_cnt(ab_full_mbar_ptr + k_stage, 1) - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init(ab_full_mbar_ptr + k_stage, 1) + cute.arch.mbarrier_init( ab_empty_mbar_ptr + k_stage, num_tma_producer ) # Accumulator barrier init if warp_idx == self.mma_warp_id: for acc_stage in range(self.num_acc_stage): with cute.arch.elect_one(): - cute.arch.mbarrier_init_arrive_cnt(acc_full_mbar_ptr + acc_stage, 1) - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init(acc_full_mbar_ptr + acc_stage, 1) + cute.arch.mbarrier_init( acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4 ) # Tensor memory dealloc barrier init @@ -581,7 +583,7 @@ class GroupedGemmKernel: if warp_idx == self.tma_warp_id: num_tmem_dealloc_threads = 32 with cute.arch.elect_one(): - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads ) cute.arch.mbarrier_init_fence() @@ -612,7 +614,7 @@ class GroupedGemmKernel: a_full_mcast_mask = None b_full_mcast_mask = None ab_empty_mcast_mask = None - if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs: + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): a_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 ) @@ -621,7 +623,7 @@ class GroupedGemmKernel: ) ab_empty_mcast_mask = a_full_mcast_mask | b_full_mcast_mask acc_full_mcast_mask = None - if use_2cta_instrs: + if cutlass.const_expr(use_2cta_instrs): acc_full_mcast_mask = cute.make_layout_image_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mode=0 ) @@ -646,15 +648,15 @@ class GroupedGemmKernel: # # Local_tile partition global tensors # - # (bM, bK, loopM, loopK, loopL) + # (bM, bK, RestM, RestK, RestL) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) - # (bN, bK, loopN, loopK, loopL) + # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) - # (bM, bN, loopM, loopN, loopL) + # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) @@ -663,11 +665,11 @@ class GroupedGemmKernel: # Partition global tensor for TiledMMA_A/B/C # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) - # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) tCgC = thr_mma.partition_C(gC_mnl) # @@ -677,7 +679,7 @@ class GroupedGemmKernel: cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape ) # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) + # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( tma_atom_a, block_in_cluster_coord_vmnk[2], @@ -690,7 +692,7 @@ class GroupedGemmKernel: cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) + # ((atom_v, rest_v), RestM, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( tma_atom_b, block_in_cluster_coord_vmnk[1], @@ -849,11 +851,11 @@ class GroupedGemmKernel: # # Slice to per mma tile index # - # ((atom_v, rest_v), loopK) + # ((atom_v, rest_v), RestK) tAgA_slice = tAgA[ (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) ] - # ((atom_v, rest_v), loopK) + # ((atom_v, rest_v), RestK) tBgB_slice = tBgB[ (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] @@ -867,7 +869,7 @@ class GroupedGemmKernel: tma_wr_ab_empty_phase = ( num_prev_k_blk + tma_wr_k_block ) // self.num_ab_stage % 2 ^ 1 - peek_ab_empty_status = cute.arch.conditional_mbarrier_try_wait( + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( tma_wr_k_block < cur_k_block_cnt, ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase, @@ -879,7 +881,7 @@ class GroupedGemmKernel: # # Tma load loop # - for k_block in cutlass.range_dynamic(0, cur_k_block_cnt, 1, unroll=1): + for k_block in cutlass.range(0, cur_k_block_cnt, 1, unroll=1): tma_wr_k_block_next = tma_wr_k_block + 1 smem_wr_buffer_next = ( num_prev_k_blk + tma_wr_k_block_next @@ -898,10 +900,10 @@ class GroupedGemmKernel: ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase ) - # Init AB buffer full transaction byte + # Arrive AB buffer and expect full transaction bytes if is_leader_cta: with cute.arch.elect_one(): - cute.arch.mbarrier_init_tx_bytes( + cute.arch.mbarrier_arrive_and_expect_tx( smem_full_mbar_ptr, self.num_tma_load_bytes ) @@ -930,7 +932,7 @@ class GroupedGemmKernel: ) # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 - peek_ab_empty_status = cute.arch.conditional_mbarrier_try_wait( + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( tma_wr_k_block_next < cur_k_block_cnt, ab_empty_mbar_ptr + smem_wr_buffer_next, tma_wr_ab_empty_phase_next, @@ -999,11 +1001,12 @@ class GroupedGemmKernel: while work_tile.is_valid_tile: cur_tile_coord = work_tile.tile_idx # MMA warp is only interested in number of tiles along K dimension - cur_k_block_cnt, cur_group_idx = ( - group_gemm_ts_helper.search_cluster_tile_count_k( - cur_tile_coord, - problem_sizes_mnkl, - ) + ( + cur_k_block_cnt, + cur_group_idx, + ) = group_gemm_ts_helper.search_cluster_tile_count_k( + cur_tile_coord, + problem_sizes_mnkl, ) # Set tensor memory buffer for current tile acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage @@ -1022,7 +1025,7 @@ class GroupedGemmKernel: mma_rd_ab_full_phase = ( (num_prev_k_blk + mma_rd_k_block) // self.num_ab_stage % 2 ) - peek_ab_full_status = cute.arch.conditional_mbarrier_try_wait( + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( need_check_rd_buffer_full, ab_full_mbar_ptr + smem_rd_buffer, mma_rd_ab_full_phase, @@ -1047,7 +1050,7 @@ class GroupedGemmKernel: # # Mma mainloop # - for k_block in cutlass.range_dynamic(0, cur_k_block_cnt, 1, unroll=1): + for k_block in range(cur_k_block_cnt): mma_rd_k_block_next = cutlass.Int32(k_block + 1) smem_rd_buffer_next = ( num_prev_k_blk + mma_rd_k_block_next @@ -1066,7 +1069,7 @@ class GroupedGemmKernel: # tCtAcc += tCrA * tCrB num_kphases = cute.size(tCrA, mode=[2]) - for kphase_idx in range(num_kphases): + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): kphase_coord = (None, None, kphase_idx, smem_rd_buffer) cute.gemm( @@ -1092,7 +1095,7 @@ class GroupedGemmKernel: mma_rd_k_block_next < cur_k_block_cnt and is_leader_cta ) - peek_ab_full_status = cute.arch.conditional_mbarrier_try_wait( + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( need_check_rd_buffer_full, ab_full_mbar_ptr + smem_rd_buffer_next, mma_rd_ab_full_phase_next, @@ -1161,19 +1164,23 @@ class GroupedGemmKernel: # # Partition for epilogue # - tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( - self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs - ) + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs ) tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( tiled_copy_t2r, tTR_rC, epi_tidx, sC ) - tma_atom_c, bSG_sC, bSG_gC_partitioned = ( - self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC) - ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC) # # Persistent tile scheduling loop @@ -1270,7 +1277,7 @@ class GroupedGemmKernel: # subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt - for subtile_idx in cutlass.range_dynamic(subtile_cnt): + for subtile_idx in range(subtile_cnt): # # Load accumulator from tensor memory buffer to register # @@ -1493,11 +1500,11 @@ class GroupedGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_mnl_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) tTR_rAcc = cute.make_fragment( @@ -1569,14 +1576,14 @@ class GroupedGemmKernel: - tCgC: The destination global memory tensor partitioned for the TMA operation. :rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] """ - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) sC_for_tma_partition = cute.group_modes(sC, 0, 2) gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) bSG_sC, bSG_gC = cpasync.tma_partition( tma_atom_c, 0, @@ -1595,7 +1602,7 @@ class GroupedGemmKernel: epi_tile: cute.Tile, c_dtype: type[cutlass.Numeric], c_layout: utils.LayoutEnum, - num_smem_capacity: int, + smem_capacity: int, occupancy: int, ) -> tuple[int, int, int]: """Computes the number of stages for accumulator, A/B operands, and epilogue based on heuristics. @@ -1614,8 +1621,8 @@ class GroupedGemmKernel: :type c_dtype: type[cutlass.Numeric] :param c_layout: Layout enum of operand C in global memory. :type c_layout: utils.LayoutEnum - :param num_smem_capacity: Total available shared memory capacity in bytes. - :type num_smem_capacity: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int :param occupancy: Target number of CTAs per SM (occupancy). :type occupancy: int @@ -1658,7 +1665,7 @@ class GroupedGemmKernel: # Subtract reserved bytes and initial epilogue bytes # Divide remaining by bytes needed per A/B stage num_ab_stage = ( - num_smem_capacity // occupancy + smem_capacity // occupancy - GroupedGemmKernel.reserved_smem_bytes - epi_bytes ) // ab_bytes_per_stage @@ -1667,7 +1674,7 @@ class GroupedGemmKernel: # Calculate remaining smem after allocating for A/B stages and reserved bytes # Add remaining unused smem to epilogue remaining_smem = ( - num_smem_capacity + smem_capacity - occupancy * ab_bytes_per_stage * num_ab_stage - occupancy * (GroupedGemmKernel.reserved_smem_bytes + epi_bytes) ) @@ -1775,20 +1782,6 @@ class GroupedGemmKernel: epi_bytes = cute.size_in_bytes(c_dtype, epi_smem_layout_staged) return ab_bytes + epi_bytes - @staticmethod - def _get_tma_atom_kind(atom_sm_cnt: int, mcast: bool): - """Select the appropriate TMA copy atom based on the number of SMs and the multicast flag.""" - if atom_sm_cnt == 2 and mcast: - return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) - elif atom_sm_cnt == 2 and not mcast: - return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) - elif atom_sm_cnt == 1 and mcast: - return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) - elif atom_sm_cnt == 1 and not mcast: - return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) - - raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") - @staticmethod def _compute_num_tmem_alloc_cols( tiled_mma: cute.TiledMma, @@ -1909,8 +1902,6 @@ def run_grouped_gemm( if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") - torch.manual_seed(2025) - # Create tensor and return the pointer, tensor, and stride def create_tensor_and_stride( l: int, @@ -1920,42 +1911,17 @@ def run_grouped_gemm( dtype: type[cutlass.Numeric], is_dynamic_layout: bool = True, ) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: - # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) - # else: (l, mode0, mode1) -> (mode0, mode1, l) - shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) - permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) - # omit stride for L mode as it is always 1 for grouped GEMM - strides = (1, mode0) if is_mode0_major else (mode1, 1) - assert dtype in {cutlass.Float16, cutlass.BFloat16, cutlass.Float32} - is_unsigned = False - - torch_dtype = cutlass_torch.dtype(dtype) - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch_dtype, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 - ), + torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 ) - torch_tensor = torch_tensor_cpu.cuda() - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - cute_tensor = from_dlpack(torch_tensor, assumed_align=16) - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic( - leading_dim=(0 if is_mode0_major else 1) - ) - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, + return ( + torch_tensor.data_ptr(), + torch_tensor, cute_tensor, - dtype, - is_dynamic_layout=is_dynamic_layout, + torch_tensor_cpu, + torch_tensor.stride()[:-1], ) - # Get pointer of the tensor - ptr = torch_tensor.data_ptr() - return ptr, torch_tensor, cute_tensor, f32_torch_tensor, strides # iterate all groups and create tensors for each group torch_fp32_tensors_abc = [] @@ -1964,15 +1930,27 @@ def run_grouped_gemm( strides_abc = [] ptrs_abc = [] for _, (m, n, k, l) in enumerate(problem_sizes_mnkl): - ptr_a, torch_tensor_a, cute_tensor_a, tensor_fp32_a, stride_mk_a = ( - create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype) - ) - ptr_b, torch_tensor_b, cute_tensor_b, tensor_fp32_b, stride_nk_b = ( - create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype) - ) - ptr_c, torch_tensor_c, cute_tensor_c, tensor_fp32_c, stride_mn_c = ( - create_tensor_and_stride(l, m, n, c_major == "m", c_dtype) - ) + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + tensor_fp32_a, + stride_mk_a, + ) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype) + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + tensor_fp32_b, + stride_nk_b, + ) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype) + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + tensor_fp32_c, + stride_mn_c, + ) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype) ptrs_abc.append([ptr_a, ptr_b, ptr_c]) torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c]) @@ -2005,19 +1983,16 @@ def run_grouped_gemm( ) # Prepare tensormap buffer for each SM num_tensormap_buffers = sm_count - tensormap_pytorch_tensor = ( - torch.empty( - ( - num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ), - dtype=torch.int64, - ) - .fill_(0) - .cuda() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, ) - tensormap_cute_tensor = from_dlpack(tensormap_pytorch_tensor, assumed_align=16) grouped_gemm = GroupedGemmKernel( acc_dtype, @@ -2027,23 +2002,30 @@ def run_grouped_gemm( tensormap_update_mode, ) - # Convert integer list to torch tensor and cute tensor - def convert_list_to_tensor(l, dtype) -> tuple[torch.Tensor, cute.Tensor]: - torch_tensor = torch.tensor(l, dtype=dtype).cuda() - cute_tensor = from_dlpack(torch_tensor, assumed_align=16) - return torch_tensor, cute_tensor - # layout (num_groups, 4):(4, 1) - problem_sizes_mnkl_torch_tensor, problem_sizes_mnkl_cute_tensor = ( - convert_list_to_tensor(problem_sizes_mnkl, torch.int32) + ( + tensor_of_dim_size_mnkl, + tensor_of_dim_size_mnkl_torch, + ) = cutlass_torch.cute_tensor_like( + torch.tensor(problem_sizes_mnkl, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, ) # layout (num_groups, 3, 2):(6, 2, 1) - strides_abc_torch_tensor, strides_abc_cute_tensor = convert_list_to_tensor( - strides_abc, torch.int32 + tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, ) + # layout (num_groups,3):(3, 1) - ptrs_abc_torch_tensor, ptrs_abc_cute_tensor = convert_list_to_tensor( - ptrs_abc, torch.int64 + tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, ) # Compute total number of cluster tiles we need to compute for given grouped GEMM problem @@ -2077,10 +2059,9 @@ def run_grouped_gemm( problem_sizes_mnkl, cluster_tile_shape_mn ) - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - # Get the raw stream pointer as a CUstream - current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Initialize Stream + current_stream = cutlass_torch.default_stream() + # Compile grouped GEMM kernel compiled_grouped_gemm = cute.compile( grouped_gemm, @@ -2088,11 +2069,11 @@ def run_grouped_gemm( initial_cute_tensors_abc[1], initial_cute_tensors_abc[2], num_groups, - problem_sizes_mnkl_cute_tensor, - strides_abc_cute_tensor, - ptrs_abc_cute_tensor, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, total_num_clusters, - tensormap_cute_tensor, + tensor_of_tensormap, max_active_clusters, current_stream, ) @@ -2104,10 +2085,10 @@ def run_grouped_gemm( initial_cute_tensors_abc[0], initial_cute_tensors_abc[1], initial_cute_tensors_abc[2], - problem_sizes_mnkl_cute_tensor, - strides_abc_cute_tensor, - ptrs_abc_cute_tensor, - tensormap_cute_tensor, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_tensormap, current_stream, ) # Execution @@ -2116,28 +2097,27 @@ def run_grouped_gemm( initial_cute_tensors_abc[0], initial_cute_tensors_abc[1], initial_cute_tensors_abc[2], - problem_sizes_mnkl_cute_tensor, - strides_abc_cute_tensor, - ptrs_abc_cute_tensor, - tensormap_cute_tensor, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_tensormap, current_stream, ) + torch.cuda.synchronize() + # Compute reference result if not skip_ref_check: - refs = [] - for a, b, _ in torch_fp32_tensors_abc: - ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu() - refs.append(ref) - for i, ((_, _, c), ref) in enumerate(zip(torch_tensors_abc, refs)): + for i, (a, b, c) in enumerate(torch_tensors_abc): + ref = torch.einsum( + "mkl,nkl->mnl", + a.cpu().to(dtype=torch.float32), + b.cpu().to(dtype=torch.float32), + ) print(f"checking group {i}") - if c_dtype == cutlass.Float32: - ref_c = ref - else: - ref_c = ref.to(cutlass_torch.dtype(c_dtype)) torch.testing.assert_close( c.cpu(), - ref_c, + ref.to(cutlass_torch.dtype(c_dtype)), atol=tolerance, rtol=1e-05, ) @@ -2266,6 +2246,8 @@ if __name__ == "__main__": else: tensormap_update_mode = utils.TensorMapUpdateMode.SMEM + torch.manual_seed(2025) + run_grouped_gemm( args.num_groups, args.problem_sizes_mnkl, diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py new file mode 100644 index 00000000..342d5580 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py @@ -0,0 +1,3619 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +from typing import List, Type, Tuple, Optional +from cuda import cuda + +import torch +import torch.nn.functional as F + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +from .mamba2_ssd_reference import ( + ssd_reference_fp32_all, + ssd_reference_lowprecision_intermediates, + analyze_relative_diffs, +) + +from .mamba2_ssd_tile_scheduler import ( + Mamba2SSDTileSchedulerParams, + Mamba2SSDTileScheduler, +) + + +class SSDKernel: + def __init__( + self, + io_dtype: Type[cutlass.Numeric], + cumsum_delta_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + L: int, + D: int, + N: int, + has_d: bool, + d_has_hdim: bool, + ): + self.io_dtype: Type[cutlass.Numeric] = io_dtype + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.cumsum_delta_dtype: Type[cutlass.Numeric] = cumsum_delta_dtype + # has_d means epilog warp performs Y += X*D fusion + self.has_d: bool = has_d + # d_has_hdim = True means D is (D, EH) shape and loaded by TMA + # d_has_hdim = False means D is (1, EH) shape and loaded directly to register + self.d_has_hdim: bool = d_has_hdim + self.tile_shape = (L, D, N) + + assert io_dtype in { + cutlass.Float16, + cutlass.BFloat16, + }, "Do not support other I/O types." + assert acc_dtype in {cutlass.Float32}, "Do not support other ACC types." + assert cumsum_delta_dtype in { + cutlass.Float32 + }, "Do not support other cumsum types." + assert not (not has_d and d_has_hdim), "D cannot have Hdim if has_d is False" + + # Hardcode default setting + self.use_2cta_instrs = False + self.cluster_shape_mnk = (1, 1, 1) + self.epi_tile = (128, 32) + + # Setup mma tile shapes + self.tile_shape_mnk_intra1 = (L, L, N) + self.tile_shape_mnk_intra2 = (L, D, L) + self.tile_shape_mnk_inter1 = (N, D, L) + self.tile_shape_mnk_inter2 = (L, D, N) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + # Launch config + self.occupancy = 1 + self.mma_inter_warp_id = 0 + self.mma_intra_warp_id = 1 + self.tma_b_c_warp_id = 2 + self.tma_deltas_x_d_warp_id = 3 + self.pre_inter_warp_id = [4, 5, 6, 7] + self.pre_intra_warp_id = [8, 9, 10, 11] + self.epilog_warp_id = [12, 13, 14, 15] + self.threads_per_cta = 32 * len( + ( + self.mma_inter_warp_id, + self.mma_intra_warp_id, + self.tma_b_c_warp_id, + self.tma_deltas_x_d_warp_id, + *self.pre_inter_warp_id, + *self.pre_intra_warp_id, + *self.epilog_warp_id, + ) + ) + self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + + # Named barriers + self.pre_inter_sync_bar_id = 1 + self.epilog_sync_bar_id = 2 + self.pre_intra_sync_bar_id = 3 + self.tmem_dealloc_sync_bar_id = 4 + + # Number of registers used by each warp + self.num_regs_uniform_warps = 24 + self.num_regs_pre_inter_warps = 168 + self.num_regs_pre_intra_warps = 208 + self.num_regs_epilogue_warps = 112 + + # Shared storage + self.shared_storage = None + + # TMEM buffer offsets + self.tmem_intra1_acc_offset = 0 + self.tmem_intra2_q_offset = 0 + self.tmem_intra2_acc_offset = 0 + self.tmem_inter1_acc_offset = 0 + self.tmem_inter2_acc_offset = 0 + self.num_tmem_cols_total = 0 + + def _setup_attributes(self): + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_intra1.thr_id.shape,), + ) + + # Setup stages + ( + self.input_stages, + self.output_stages, + self.internal_stages, + self.intra1_acc_stages, + ) = self._compute_stages( + self.smem_capacity, + ) + + # Setup smem layouts + # X is B operand (from smem) of INTRA2_MMA and INTER1_MMA + self.x_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + self.io_dtype, + self.input_stages, + ) + self.num_x_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.x_smem_layout, (None, None, None, 0)) + ) + + # XT is same shape as ACC operand of INTER2_MMA, before postprocessing by EPILOG + self.xt_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.tile_shape_mnk_intra2[:2], + self.input_stages, + ) + + # B is B operand (from smem) of INTRA1_MMA + self.b_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + self.io_dtype, + self.input_stages, + ) + self.num_b_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.b_smem_layout, (None, None, None, 0)) + ) + + # B_INTERNAL is also A operand (from smem) of INTER1_MMA, after preprocessed by PRE_INTER + self.bt_internal_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + self.io_dtype, + self.internal_stages, + ) + + # B needs to be proprocessed to be used as A operand of INTER1_MMA + self.bt_smem_layout = cute.coalesce( + sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + (self.tile_shape_mnk_inter1[0], self.tile_shape_mnk_inter1[2]), + self.input_stages, + ), + target_profile=(1, 1, 1), + ) + + # C is A operand (from smem) of INTRA1_MMA and INTER2_MMA + self.c_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + self.io_dtype, + self.input_stages, + ) + self.num_c_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.c_smem_layout, (None, None, None, 0)) + ) + + # P is B operand (from smem) of INTER2_MMA, after preprocessed by PRE_INTER + self.p_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + self.io_dtype, + self.internal_stages, + ) + + # PT is ACC operand (from tmem) of INTER1_MMA, after postprocessed by PRE_INTER + self.pt_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.tile_shape_mnk_inter1[:2], + self.internal_stages, + ) + + # Q is A operand (from tmem) of INTRA2_MMA, after preprocessed by PRE_INTRA + self.q_tmem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + self.io_dtype, + self.internal_stages, + ) + + # P is ACC operand (from tmem) of INTER1_MMA, to be TMA stored by PRE_INTER + self.p_smem_layout_store = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + self.tile_shape_mnk_inter2[1:], + self.internal_stages, + ) + + # Y is ACC operand (from smem) of INTER2_MMA and INTRA2_MMA, after postprocessed and TMA stored by EPILOG + self.y_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.epi_tile, + self.output_stages, + ) + + # Delta is linear smem layouts for pre/post processing + self.delta_linear_smem_layout = cute.make_layout( + (self.tile_shape_mnk_inter1[2], self.input_stages) + ) + self.num_delta_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.delta_linear_smem_layout, (None, 0)) + ) + + # Cumsum delta is linear smem layouts for pre/post processing + self.cumsum_delta_linear_smem_layout = cute.make_layout( + (self.tile_shape_mnk_inter1[2], self.input_stages) + ) + self.num_cumsum_delta_load_bytes = cute.size_in_bytes( + self.cumsum_delta_dtype, + cute.slice_(self.cumsum_delta_linear_smem_layout, (None, 0)), + ) + + # D is linear smem layouts when d_has_hdim is True + self.d_linear_smem_layout = ( + cute.make_layout((self.tile_shape_mnk_inter2[1], self.input_stages)) + if self.d_has_hdim + else None + ) + self.num_d_load_bytes = ( + cute.size_in_bytes( + self.io_dtype, + cute.slice_(self.d_linear_smem_layout, (None, 0)), + ) + if self.d_has_hdim + else 0 + ) + + # Setup tmem offsets + ( + self.tmem_intra1_acc_offset, + self.tmem_intra2_q_offset, + self.tmem_intra2_acc_offset, + self.tmem_inter1_acc_offset, + self.tmem_inter2_acc_offset, + self.num_tmem_cols_total, + ) = self._plan_tmem_offsets( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + self.internal_stages, + self.q_tmem_layout, + self.io_dtype, + self.internal_stages, + self.intra1_acc_stages, + ) + + return + + @cute.jit + def __call__( + self, + x: cute.Tensor, + cumsum_delta: cute.Tensor, + delta: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + y: cute.Tensor, + fstate: cute.Tensor, + d: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + self._setup_attributes() + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + # Setup TMA atoms and convert TMA tensors + # TMA load for A + x_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, tiled_mma_intra2.thr_id + ) + tma_atom_x, tma_tensor_x = cute.nvgpu.make_tiled_tma_atom_B( + x_op, + x, + cute.slice_(self.x_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra2, + tiled_mma_intra2, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if x.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, tiled_mma_intra1.thr_id + ) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + cute.slice_(self.b_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra1, + tiled_mma_intra1, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for C + c_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mnk, tiled_mma_intra1.thr_id + ) + tma_atom_c, tma_tensor_c = cute.nvgpu.make_tiled_tma_atom_A( + c_op, + c, + cute.slice_(self.c_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra1, + tiled_mma_intra1, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if c.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for delta + # TODO: use bulkcp instead of tma + delta_cta_v_layout = cute.slice_( + cute.make_identity_layout(delta.shape), (None, 0, 0, 0) + ) + delta_linear_smem_layout = cute.slice_(self.delta_linear_smem_layout, (None, 0)) + tma_atom_delta, tma_tensor_delta = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + delta, + delta_linear_smem_layout, + delta_cta_v_layout, + ) + + # TMA load for cumsum_delta + cumsum_delta_cta_v_layout = cute.slice_( + cute.make_identity_layout(cumsum_delta.shape), (None, 0, 0, 0) + ) + cumsum_delta_linear_smem_layout = cute.slice_( + self.cumsum_delta_linear_smem_layout, (None, 0) + ) + ( + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + ) = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + cumsum_delta, + cumsum_delta_linear_smem_layout, + cumsum_delta_cta_v_layout, + ) + + tma_atom_d = None + tma_tensor_d = d + # TMA load for D + if cutlass.const_expr(self.d_has_hdim): + d_cta_v_layout = cute.slice_(cute.make_identity_layout(d.shape), (None, 0)) + d_linear_smem_layout = cute.slice_(self.d_linear_smem_layout, (None, 0)) + ( + tma_atom_d, + tma_tensor_d, + ) = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + d, + d_linear_smem_layout, + d_cta_v_layout, + ) + + # TMA store for y + y_cta_v_layout = cute.composition( + cute.make_identity_layout(y.shape), self.epi_tile + ) + y_smem_layout = cute.slice_(self.y_smem_layout, (None, None, 0)) + tma_atom_y, tma_tensor_y = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + y, + y_smem_layout, + y_cta_v_layout, + ) + + # TMA store for fstate(p) + p_cta_v_layout = cute.slice_( + cute.make_identity_layout(fstate.shape), (None, None, 0, 0) + ) + p_smem_layout_store = cute.slice_(self.p_smem_layout_store, (None, None, 0)) + tma_atom_p, tma_tensor_p = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + fstate, + p_smem_layout_store, + p_cta_v_layout, + ) + + # Compute grid size + tile_sched_params, grid = self._compute_grid(y, b, max_active_clusters) + + # Plan shared memory storage + swizzle_buffer_align_bytes = 1024 + nonswizzle_buffer_align_bytes = 128 + + @cute.struct + class SharedStorage: + # Input stage barriers + x_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + x_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + b_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + b_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + c_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + c_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + deltas_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + deltas_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + d_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + d_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + # Intra1 acc stage barriers + intra1_acc_full: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore + intra1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore + # Internal stage barriers + intra2_q_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_q_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_b_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_b_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_p_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_p_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + # Smem tensors + smem_x: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.x_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_b: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.b_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_bt_internal: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, cute.cosize(self.bt_internal_smem_layout) + ], + swizzle_buffer_align_bytes, + ] + smem_c: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.c_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.p_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_y: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.y_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_cumsum_delta: cute.struct.Align[ + cute.struct.MemRange[ + self.cumsum_delta_dtype, + cute.cosize(self.cumsum_delta_linear_smem_layout), + ], + nonswizzle_buffer_align_bytes, + ] + smem_delta: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, cute.cosize(self.delta_linear_smem_layout) + ], + nonswizzle_buffer_align_bytes, + ] + smem_d: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, + cute.cosize(self.d_linear_smem_layout) if self.d_has_hdim else 0, + ], + nonswizzle_buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + if cutlass.const_expr(self.shared_storage.size_in_bytes() > self.smem_capacity): + raise ValueError( + f"SharedStorage size {self.shared_storage.size_in_bytes()} exceeds smem_capacity {self.smem_capacity}" + ) + + # Launch the kernel synchronously + self.kernel( + tma_atom_x, + tma_tensor_x, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tma_atom_p, + tma_tensor_p, + tma_atom_y, + tma_tensor_y, + tma_atom_delta, + tma_tensor_delta, + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + tma_atom_d, + tma_tensor_d, + self.cluster_layout_vmnk, + self.x_smem_layout, + self.xt_smem_layout, + self.b_smem_layout, + self.bt_smem_layout, + self.bt_internal_smem_layout, + self.c_smem_layout, + self.pt_smem_layout, + self.p_smem_layout, + self.q_tmem_layout, + self.p_smem_layout_store, + self.y_smem_layout, + self.delta_linear_smem_layout, + self.cumsum_delta_linear_smem_layout, + self.d_linear_smem_layout, + self.epi_tile, + tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + min_blocks_per_mp=1, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_x: cute.CopyAtom, + tma_tensor_x: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_c: cute.CopyAtom, + tma_tensor_c: cute.Tensor, + tma_atom_p: cute.CopyAtom, + tma_tensor_p: cute.Tensor, + tma_atom_y: cute.CopyAtom, + tma_tensor_y: cute.Tensor, + tma_atom_delta: cute.CopyAtom, + tma_tensor_delta: cute.Tensor, + tma_atom_cumsum_delta: cute.CopyAtom, + tma_tensor_cumsum_delta: cute.Tensor, + tma_atom_d: Optional[cute.CopyAtom], + tma_tensor_d: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + x_smem_layout: cute.ComposedLayout, + xt_smem_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, + bt_smem_layout: cute.ComposedLayout, + bt_internal_smem_layout: cute.ComposedLayout, + c_smem_layout: cute.ComposedLayout, + pt_smem_layout: cute.ComposedLayout, + p_smem_layout: cute.ComposedLayout, + q_tmem_layout: cute.ComposedLayout, + p_smem_layout_store: cute.ComposedLayout, + y_smem_layout: cute.ComposedLayout, + delta_linear_smem_layout: cute.Layout, + cumsum_delta_linear_smem_layout: cute.Layout, + d_linear_smem_layout: Optional[cute.Layout], + epi_tile: cute.Tile, + tile_sched_params: Mamba2SSDTileSchedulerParams, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == 0: + tma_atoms = [ + tma_atom_x, + tma_atom_b, + tma_atom_c, + tma_atom_p, + tma_atom_y, + tma_atom_delta, + tma_atom_cumsum_delta, + ] + if cutlass.const_expr(self.d_has_hdim): + tma_atoms.append(tma_atom_d) + for tma_atom in tma_atoms: + cpasync.prefetch_descriptor(tma_atom) + + # Static consts + D = cute.size(tma_tensor_x, mode=[0]) + L = cute.size(tma_tensor_x, mode=[1]) + N = cute.size(tma_tensor_b, mode=[1]) + # Dynamic values + C = cute.size(tma_tensor_x, mode=[2]) + EH = cute.size(tma_tensor_x, mode=[3]) + B = cute.size(tma_tensor_x, mode=[4]) + G = cute.size(tma_tensor_b, mode=[3]) + NGROUP_RATIO = EH // G + + # Make tiledMma + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + # Setup cta/thread coordinates + # Block coord + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_intra1.thr_id.shape) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Workload coord + tile_sched = Mamba2SSDTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Thread/warp coord + tidx, _, _ = cute.arch.thread_idx() + # Thread coord inside specialized warps + local_tidx = tidx % 128 + local_warp_idx = cute.arch.make_warp_uniform(local_tidx // 32) + + # Alloc and init smem tensors and pipelines + smem = utils.SmemAllocator() + smem_storage = smem.allocate(self.shared_storage) + + # Setup smem tensors + smem_x = smem_storage.smem_x.get_tensor( + x_smem_layout.outer, swizzle=x_smem_layout.inner + ) + smem_xt = smem_storage.smem_x.get_tensor( + xt_smem_layout.outer, swizzle=xt_smem_layout.inner + ) + smem_b = smem_storage.smem_b.get_tensor( + b_smem_layout.outer, swizzle=b_smem_layout.inner + ) + smem_bt = smem_storage.smem_b.get_tensor( + bt_smem_layout.outer, swizzle=bt_smem_layout.inner + ) + smem_bt_internal = smem_storage.smem_bt_internal.get_tensor( + bt_internal_smem_layout.outer, swizzle=bt_internal_smem_layout.inner + ) + smem_c = smem_storage.smem_c.get_tensor( + c_smem_layout.outer, swizzle=c_smem_layout.inner + ) + smem_p = smem_storage.smem_p.get_tensor( + p_smem_layout.outer, swizzle=p_smem_layout.inner + ) + smem_pt = smem_storage.smem_p.get_tensor( + pt_smem_layout.outer, swizzle=pt_smem_layout.inner + ) + smem_p_store = smem_storage.smem_p.get_tensor( + p_smem_layout_store.outer, swizzle=p_smem_layout_store.inner + ) + smem_y = smem_storage.smem_y.get_tensor( + y_smem_layout.outer, swizzle=y_smem_layout.inner + ) + smem_cumsum_delta = smem_storage.smem_cumsum_delta.get_tensor( + cumsum_delta_linear_smem_layout + ) + smem_delta = smem_storage.smem_delta.get_tensor(delta_linear_smem_layout) + smem_d = None + if cutlass.const_expr(self.d_has_hdim): + smem_d = smem_storage.smem_d.get_tensor(d_linear_smem_layout) + + # Init mbarrier for pipeline + x_pipeline = self.make_and_init_x_pipeline(smem_storage.x_full.data_ptr()) + b_pipeline = self.make_and_init_b_pipeline(smem_storage.b_full.data_ptr()) + c_pipeline = self.make_and_init_c_pipeline(smem_storage.c_full.data_ptr()) + deltas_pipeline = self.make_and_init_deltas_pipeline( + smem_storage.deltas_full.data_ptr() + ) + d_pipeline = self.make_and_init_d_pipeline(smem_storage.d_full.data_ptr()) + intra1_acc_pipeline = self.make_and_init_intra1_acc_pipeline( + smem_storage.intra1_acc_full.data_ptr() + ) + intra2_q_pipeline = self.make_and_init_intra2_q_pipeline( + smem_storage.intra2_q_full.data_ptr() + ) + intra2_acc_pipeline = self.make_and_init_intra2_acc_pipeline( + smem_storage.intra2_acc_full.data_ptr() + ) + inter1_b_pipeline = self.make_and_init_inter1_b_pipeline( + smem_storage.inter1_b_full.data_ptr() + ) + inter1_acc_pipeline = self.make_and_init_inter1_acc_pipeline( + smem_storage.inter1_acc_full.data_ptr() + ) + inter2_p_pipeline = self.make_and_init_inter2_p_pipeline( + smem_storage.inter2_p_full.data_ptr() + ) + inter2_acc_pipeline = self.make_and_init_inter2_acc_pipeline( + smem_storage.inter2_acc_full.data_ptr() + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_arrive_relaxed() + + # Cluster wait before tmem alloc + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_wait() + + # Alloc tmem buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_cols_total, + smem_storage.tmem_holding_buf, + is_two_cta=self.use_2cta_instrs, + ) + + # Bar sync before retrieving tmem ptr from shared mem + cute.arch.barrier() + + # Retrieve tmem ptr + tmem_ptr_base = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=smem_storage.tmem_holding_buf, + ) + + # Specialized TMA load Delta/CumsumDelta/X warp + if warp_idx == self.tma_deltas_x_d_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, EH, B) + tXsX, tXgX_pre_slice = self.tma_partition_for_mma_b_operand( + tma_atom_x, + tma_tensor_x, + smem_x, + tiled_mma_intra2, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + tDeltasDelta, tDeltagDelta_pre_slice = self.tma_partition_with_shape( + tma_atom_delta, + tma_tensor_delta, + smem_delta, + (self.tile_shape_mnk_inter1[2],), + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + ( + tDeltasCumsumDelta, + tDeltagCumsumDelta_pre_slice, + ) = self.tma_partition_with_shape( + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + smem_cumsum_delta, + (self.tile_shape_mnk_inter1[2],), + ) + + tDsD = None + tDgD_pre_slice = None + if cutlass.const_expr(self.d_has_hdim): + # Partition global/shared tensor for D + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, EH) + tDsD, tDgD_pre_slice = self.tma_partition_with_shape( + tma_atom_d, tma_tensor_d, smem_d, (self.tile_shape_mnk_inter2[1],) + ) + + # Pipeline X/Delta/CumsumDelta/D producer state + x_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + deltas_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + d_producer_state = None + if cutlass.const_expr(self.d_has_hdim): + # D is loaded by TMA only when d_has_hdim is True + d_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), C) + tXgX = tXgX_pre_slice[None, 0, 0, None, eh_idx, b_idx] + tDeltagDelta = tDeltagDelta_pre_slice[None, 0, None, eh_idx, b_idx] + tDeltagCumsumDelta = tDeltagCumsumDelta_pre_slice[ + None, 0, None, eh_idx, b_idx + ] + tDgD = None + if cutlass.const_expr(self.d_has_hdim): + # ((ATOM_V, REST_V)) + tDgD = tDgD_pre_slice[None, 0, eh_idx] + + # Reset count for pipeline state + x_producer_state.reset_count() + deltas_producer_state.reset_count() + if cutlass.const_expr(self.d_has_hdim): + d_producer_state.reset_count() + + # Peek (try_wait) X/deltas buffer empty status + peek_x_empty_status = self.conditional_producer_try_acquire( + x_producer_state, x_pipeline, C + ) + peek_deltas_empty_status = self.conditional_producer_try_acquire( + deltas_producer_state, deltas_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + # Wait for D buffer empty + d_pipeline.producer_acquire(d_producer_state) + # TMA load D + cute.copy( + tma_atom_d, + tDgD, + tDsD[None, d_producer_state.index], + tma_bar_ptr=d_pipeline.producer_get_barrier(d_producer_state), + ) + # Advance D producer state + d_producer_state.advance() + + # Batched load over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for X buffer empty + x_pipeline.producer_acquire(x_producer_state, peek_x_empty_status) + + # TMA load X + cute.copy( + tma_atom_x, + tXgX[None, x_producer_state.count], + tXsX[None, x_producer_state.index], + tma_bar_ptr=x_pipeline.producer_get_barrier(x_producer_state), + ) + + # Conditionally wait for deltas buffer empty + deltas_pipeline.producer_acquire( + deltas_producer_state, peek_deltas_empty_status + ) + + # TMA load Delta/CumsumDelta + cute.copy( + tma_atom_delta, + tDeltagDelta[None, deltas_producer_state.count], + tDeltasDelta[None, deltas_producer_state.index], + tma_bar_ptr=deltas_pipeline.producer_get_barrier( + deltas_producer_state + ), + ) + cute.copy( + tma_atom_cumsum_delta, + tDeltagCumsumDelta[None, deltas_producer_state.count], + tDeltasCumsumDelta[None, deltas_producer_state.index], + tma_bar_ptr=deltas_pipeline.producer_get_barrier( + deltas_producer_state + ), + ) + + # Advance X/deltas producer state + x_producer_state.advance() + deltas_producer_state.advance() + + # Peek (try_wait) X/deltas buffer empty status + peek_x_empty_status = self.conditional_producer_try_acquire( + x_producer_state, x_pipeline, C + ) + peek_deltas_empty_status = self.conditional_producer_try_acquire( + deltas_producer_state, deltas_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for X/Deltas/D + x_pipeline.producer_tail(x_producer_state) + deltas_pipeline.producer_tail(deltas_producer_state) + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.producer_tail(d_producer_state) + # END of specialized tma load X/Deltas/D warp + + # Specialized TMA load B/C warp + elif warp_idx == self.tma_b_c_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tBsB, tBgB_pre_slice = self.tma_partition_for_mma_b_operand( + tma_atom_b, + tma_tensor_b, + smem_b, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tCsC, tCgC_pre_slice = self.tma_partition_for_mma_a_operand( + tma_atom_c, + tma_tensor_c, + smem_c, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # Pipeline B/C producer state + b_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + c_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), C) + tBgB = tBgB_pre_slice[None, 0, 0, None, g_idx, b_idx] + tCgC = tCgC_pre_slice[None, 0, 0, None, g_idx, b_idx] + + # Reset count for pipeline state + b_producer_state.reset_count() + c_producer_state.reset_count() + + # Peek (try_wait) B/C buffer empty status + peek_b_empty_status = self.conditional_producer_try_acquire( + b_producer_state, b_pipeline, C + ) + peek_c_empty_status = self.conditional_producer_try_acquire( + c_producer_state, c_pipeline, C + ) + + # Batched load over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for B buffer empty + b_pipeline.producer_acquire(b_producer_state, peek_b_empty_status) + + # TMA load B + cute.copy( + tma_atom_b, + tBgB[None, b_producer_state.count], + tBsB[None, b_producer_state.index], + tma_bar_ptr=b_pipeline.producer_get_barrier(b_producer_state), + ) + + # Conditionally wait for C buffer empty + c_pipeline.producer_acquire(c_producer_state, peek_c_empty_status) + + # TMA load C + cute.copy( + tma_atom_c, + tCgC[None, c_producer_state.count], + tCsC[None, c_producer_state.index], + tma_bar_ptr=c_pipeline.producer_get_barrier(c_producer_state), + ) + + # Advance B/C producer state + b_producer_state.advance() + c_producer_state.advance() + + # Peek (try_wait) B/C buffer empty status + peek_b_empty_status = self.conditional_producer_try_acquire( + b_producer_state, b_pipeline, C + ) + peek_c_empty_status = self.conditional_producer_try_acquire( + c_producer_state, c_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for B/C + b_pipeline.producer_tail(b_producer_state) + c_pipeline.producer_tail(c_producer_state) + # END of specialized tma load B/C warp + + # Specialized MMA Intra warp + elif warp_idx == self.mma_intra_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # Make shared/tmem fragments for INTRA_MMA1 B/C/ACC + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE) + tCrC, tCrB, tCtAccIntra1 = self.mma_partition_ss( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + smem_c, + smem_b, + tmem_ptr_base + self.tmem_intra1_acc_offset, + self.intra1_acc_stages, + ) + + # Make shared/tmem fragments for INTRA_MMA2 X/Q/ACC + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrQ, tCrX, tCtAccIntra2 = self.mma_partition_ts( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + q_tmem_layout, + smem_x, + tmem_ptr_base + self.tmem_intra2_q_offset, + tmem_ptr_base + self.tmem_intra2_acc_offset, + self.internal_stages, + ) + + # Pipeline B/C/X/INTRA2_Q consumer state + b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + c_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra2_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTRA1_ACC/INTRA2_ACC producer state + intra1_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.intra1_acc_stages + ) + intra2_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + b_consumer_state.reset_count() + c_consumer_state.reset_count() + intra1_acc_producer_state.reset_count() + x_consumer_state.reset_count() + intra2_q_consumer_state.reset_count() + intra2_acc_producer_state.reset_count() + + # Peek (try_wait) B/C/X/INTRA1_ACC buffer full/full/full/empty status + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + # Manual pipeline: unrolled INTRA_MMA1 chunk_idx = 0 loop + # Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + intra1_acc_pipeline.producer_acquire( + intra1_acc_producer_state, peek_wr_intra1_acc_empty_status + ) + + # INTRA_MMA1 + tiled_mma_intra1 = self.exec_mma( + tiled_mma_intra1, + tCtAccIntra1, + tCrC, + tCrB, + intra1_acc_producer_state, + c_consumer_state, + b_consumer_state, + ) + + # Async arrive B/C/INTRA1_ACC buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + c_pipeline.consumer_release(c_consumer_state) + intra1_acc_pipeline.producer_commit(intra1_acc_producer_state) + + # Advance B/C/INTRA1_ACC state + b_consumer_state.advance() + c_consumer_state.advance() + intra1_acc_producer_state.advance() + + # Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + + # Manual pipeline: batched gemm over C-1 dimension + for chunk_idx in cutlass.range(C - 1, unroll=1): + # Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + intra1_acc_pipeline.producer_acquire( + intra1_acc_producer_state, peek_wr_intra1_acc_empty_status + ) + + # INTRA_MMA1 + tiled_mma_intra1 = self.exec_mma( + tiled_mma_intra1, + tCtAccIntra1, + tCrC, + tCrB, + intra1_acc_producer_state, + c_consumer_state, + b_consumer_state, + ) + + # Async arrive B/C/INTRA1_ACC buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + c_pipeline.consumer_release(c_consumer_state) + intra1_acc_pipeline.producer_commit(intra1_acc_producer_state) + + # Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status) + intra2_q_pipeline.consumer_wait(intra2_q_consumer_state) + intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state) + + # INTRA_MMA2 + tiled_mma_intra2 = self.exec_mma( + tiled_mma_intra2, + tCtAccIntra2, + tCrQ, + tCrX, + intra2_acc_producer_state, + intra2_q_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + intra2_q_pipeline.consumer_release(intra2_q_consumer_state) + intra2_acc_pipeline.producer_commit(intra2_acc_producer_state) + + # Advance B/C/INTRA1_ACC cstate + b_consumer_state.advance() + c_consumer_state.advance() + intra1_acc_producer_state.advance() + + # Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = ( + self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + ) + + # Advance X/INTRA2_Q/INTRA2_ACC state + x_consumer_state.advance() + intra2_q_consumer_state.advance() + intra2_acc_producer_state.advance() + + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C-1, unroll=1) + + # Manual pipeline: unrolled INTRA_MMA2 chunk_idx = C-1 loop + # Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status) + intra2_q_pipeline.consumer_wait(intra2_q_consumer_state) + intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state) + + # INTRA_MMA2 + tiled_mma_intra2 = self.exec_mma( + tiled_mma_intra2, + tCtAccIntra2, + tCrQ, + tCrX, + intra2_acc_producer_state, + intra2_q_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + intra2_q_pipeline.consumer_release(intra2_q_consumer_state) + intra2_acc_pipeline.producer_commit(intra2_acc_producer_state) + + # Advance X/INTRA2_Q/INTRA2_ACC state + x_consumer_state.advance() + intra2_q_consumer_state.advance() + intra2_acc_producer_state.advance() + + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTRA1_ACC/INTRA2_ACC + intra1_acc_pipeline.producer_tail(intra1_acc_producer_state) + intra2_acc_pipeline.producer_tail(intra2_acc_producer_state) + # END of specialized mma-intra warp + + # Specialized MMA Inter warp + elif warp_idx == self.mma_inter_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # Make shared/tmem fragments for INTER_MMA1 X/B/ACC + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrB, tCrX, tCtAccInter1 = self.mma_partition_ss( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + smem_bt_internal, + smem_x, + tmem_ptr_base + self.tmem_inter1_acc_offset, + self.internal_stages, + ) + + # Make shared/tmem fragments for INTER_MMA2 C/P/ACC + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + # (MMA, MMA_N, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrC, tCrP, tCtAccInter2 = self.mma_partition_ss( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + smem_c, + smem_p, + tmem_ptr_base + self.tmem_inter2_acc_offset, + self.internal_stages, + ) + + # Pipeline X/C/INTER1_B/INTER2_P consumer state + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + c_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + inter1_b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + inter2_p_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTER1_ACC/INTER2_ACC producer state + inter1_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + inter2_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + x_consumer_state.reset_count() + c_consumer_state.reset_count() + inter1_acc_producer_state.reset_count() + inter1_b_consumer_state.reset_count() + inter2_p_consumer_state.reset_count() + inter2_acc_producer_state.reset_count() + + # Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty status + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_inter2_p_full_status = self.conditional_consumer_try_wait( + inter2_p_consumer_state, inter2_p_pipeline, C + ) + peek_inter2_acc_empty_status = self.conditional_producer_try_acquire( + inter2_acc_producer_state, inter2_acc_pipeline, C + ) + + # Batched gemm over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for C/INTER2_P/INTER2_ACC buffer full/full/empty + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + inter2_p_pipeline.consumer_wait( + inter2_p_consumer_state, peek_inter2_p_full_status + ) + inter2_acc_pipeline.producer_acquire( + inter2_acc_producer_state, peek_inter2_acc_empty_status + ) + + # INTER MMA2 + tiled_mma_inter2 = self.exec_mma( + tiled_mma_inter2, + tCtAccInter2, + tCrC, + tCrP, + inter2_acc_producer_state, + c_consumer_state, + inter2_p_consumer_state, + ) + + # Async arrive C/INTER2_P/INTER2_ACC buffer empty/empty/full + c_pipeline.consumer_release(c_consumer_state) + inter2_p_pipeline.consumer_release(inter2_p_consumer_state) + inter2_acc_pipeline.producer_commit(inter2_acc_producer_state) + + # Wait for X/INTER1_B/INTER1_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state) + inter1_b_pipeline.consumer_wait(inter1_b_consumer_state) + inter1_acc_pipeline.producer_acquire(inter1_acc_producer_state) + + # INTER MMA1 + tiled_mma_inter1 = self.exec_mma( + tiled_mma_inter1, + tCtAccInter1, + tCrB, + tCrX, + inter1_acc_producer_state, + inter1_b_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTER1_B/INTER1_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + inter1_b_pipeline.consumer_release(inter1_b_consumer_state) + inter1_acc_pipeline.producer_commit(inter1_acc_producer_state) + + # Advance X/C/INTER1_B/INTER1_ACC/INTER2_P/INTER2_ACC state + x_consumer_state.advance() + c_consumer_state.advance() + inter1_b_consumer_state.advance() + inter1_acc_producer_state.advance() + inter2_p_consumer_state.advance() + inter2_acc_producer_state.advance() + + # Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_inter2_p_full_status = self.conditional_consumer_try_wait( + inter2_p_consumer_state, inter2_p_pipeline, C + ) + peek_inter2_acc_empty_status = ( + self.conditional_producer_try_acquire( + inter2_acc_producer_state, inter2_acc_pipeline, C + ) + ) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Producer tail for INTER1_ACC/INTER2_ACC + inter1_acc_pipeline.producer_tail(inter1_acc_producer_state) + inter2_acc_pipeline.producer_tail(inter2_acc_producer_state) + + # Specialized Pre-Inter warp + elif ( + warp_idx == self.pre_inter_warp_id[0] + or warp_idx == self.pre_inter_warp_id[1] + or warp_idx == self.pre_inter_warp_id[2] + or warp_idx == self.pre_inter_warp_id[3] + ): + # Alloc regs in pre_inter warps + cute.arch.warpgroup_reg_alloc(self.num_regs_pre_inter_warps) + + # Make tiledCopy and partition smem/register tensor for smem load Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tiled_s2r_b, tBsB_s2r, tBrB_s2r = self.pre_inter_smem_load_and_partition_b( + local_tidx, smem_bt + ) + + # Partition shared tensor for smem store Bt + smem_bt_internal_ = cute.make_tensor( + smem_bt_internal.iterator, smem_bt.layout + ) + # Make tiledCopy and partition register/smem tensor for smem store Bt + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b( + local_tidx, smem_bt_internal_, tiled_s2r_b, tBrB_s2r + ) + + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + sDelta = self.pre_inter_make_delta(smem_delta, smem_bt.layout) + sDeltaA = self.pre_inter_make_delta(smem_cumsum_delta, smem_bt.layout) + + # Make copy_atom and partition register/smem tensor for smem load/store of Delta/DeltaA + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + ( + s2r_atom_delta, + tBsDelta_s2r, + tBrDelta_s2r, + ) = self.smem_load_and_partition_delta_d( + tiled_s2r_b, local_tidx, sDelta, (None, None, None, 0) + ) + ( + s2r_atom_cumsum, + tBsDeltaA_s2r, + tBrDeltaA_s2r, + ) = self.smem_load_and_partition_delta_d( + tiled_s2r_b, local_tidx, sDeltaA, (None, None, None, 0) + ) + + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + thr_r2s_b = tiled_r2s_b.get_slice(local_tidx) + tBrDelta_r2s = thr_r2s_b.retile(tBrDelta_s2r) + tBrDeltaA_r2s = thr_r2s_b.retile(tBrDeltaA_s2r) + + # Make tmem fragment for INTER1_ACC + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAccInter1 = self.mma_partition_c( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + tmem_ptr_base + self.tmem_inter1_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tInter1 = tCtAccInter1[((None, None), 0, 0, None)] + + # Make tiledCopy and partition tmem/register tensor for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + tiled_t2r_inter1, + tTR_tP, + tTR_rP, + ) = self.pre_inter_tmem_load_and_partition_p(local_tidx, tInter1, smem_pt) + + # Make fragment for register to hold P after post-processing (in acc dtype) + tState = cute.make_fragment(tTR_rP.shape, self.acc_dtype) + + # Make tiledCopy and partition smem/register tensor for smem store INTER2_P + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tiled_r2s_p, tRS_rP, tRS_sP = self.smem_store_and_partition_p_y( + local_tidx, smem_pt, tiled_t2r_inter1 + ) + + # Partition global/shared tensor for P (State) + # ((ATOM_V, REST_V), INTERNAL_STAGE) + # ((ATOM_V, REST_V), 1, 1, EH, B) + bSG_sP, bSG_gP_pre_slice = self.tma_partition_with_shape( + tma_atom_p, + tma_tensor_p, + smem_p_store, + self.tile_shape_mnk_inter2[1:], + ) + + # Pipeline B/Delta/INTER1_ACC consumer state + b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + inter1_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTER1_B/INTER2_P producer state + inter1_b_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + inter2_p_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + # Pipeline TMA store P + tma_p_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.internal_stages, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ), + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V)) + bSG_gP = bSG_gP_pre_slice[(None, 0, 0, eh_idx, b_idx)] + + # Reset count for pipeline state + b_consumer_state.reset_count() + deltas_consumer_state.reset_count() + inter1_b_producer_state.reset_count() + inter1_acc_consumer_state.reset_count() + inter2_p_producer_state.reset_count() + + # State (P) init + tState.fill(0.0) + + # Peek (try_wait) B/Delta/INTER1_B buffer full/full/empty status + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_wr_inter1_b_empty_status = self.conditional_producer_try_acquire( + inter1_b_producer_state, inter1_b_pipeline, C + ) + + # Prefill INTER2_P with 0 + # Wait for INTER2_P buffer empty + inter2_p_pipeline.producer_acquire(inter2_p_producer_state) + + tRS_rP.fill(0.0) + # Copy INTER2_P from register to smem + inter2_p_coord = (None, None, None, inter2_p_producer_state.index) + cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # Async arrive INTER2_P buffer full + inter2_p_pipeline.producer_commit(inter2_p_producer_state) + # Advance INTER2_P producer state + inter2_p_producer_state.advance() + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for B/Delta/B_TMEM buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + inter1_b_pipeline.producer_acquire( + inter1_b_producer_state, peek_wr_inter1_b_empty_status + ) + + # Load B/Delta/DeltaA/last_column + b_coord = (None, None, None, b_consumer_state.index) + delta_coord = (None, None, None, deltas_consumer_state.index) + cute.copy(tiled_s2r_b, tBsB_s2r[b_coord], tBrB_s2r) + cute.copy(s2r_atom_delta, tBsDelta_s2r[delta_coord], tBrDelta_s2r) + cute.copy( + s2r_atom_cumsum, tBsDeltaA_s2r[delta_coord], tBrDeltaA_s2r + ) + last_column = smem_cumsum_delta[ + smem_cumsum_delta.shape[0] - 1, deltas_consumer_state.index + ] + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Combine B/Delta/DeltaA/last_column + tScaledB = self.pre_inter_scale_bt_with_delta( + tBrB_s2r, tBrDelta_r2s, tBrDeltaA_r2s, last_column + ) + + # Store scaled B to tBrB_r2s + for reg_idx in range(cute.size(tBrB_r2s)): + tBrB_r2s[reg_idx] = tScaledB[reg_idx].to(self.io_dtype) + + # Store tBrB_r2s to bt_smem_internal + inter1_b_coord = (None, None, None, inter1_b_producer_state.index) + cute.copy(tiled_r2s_b, tBrB_r2s, tBsB_r2s[inter1_b_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Async arrive B/Delta/B_TMEM buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.AsyncThread + ) + deltas_pipeline.consumer_release(deltas_consumer_state) + inter1_b_pipeline.producer_commit(inter1_b_producer_state) + + # Wait for INTER1_ACC/INTER2_P buffer full/empty + inter1_acc_pipeline.consumer_wait(inter1_acc_consumer_state) + inter2_p_pipeline.producer_acquire(inter2_p_producer_state) + + # Load INTER1_ACC + inter1_acc_coord = ( + None, + None, + None, + inter1_acc_consumer_state.index, + ) + cute.copy(tiled_t2r_inter1, tTR_tP[inter1_acc_coord], tTR_rP) + + # Fence for TMEM load + cute.arch.fence_view_async_tmem_load() + + # Combine INTER1_ACC/last_column/State + exp_last_column = cute.arch.exp(last_column.ir_value()) + for reg_idx in range(0, cute.size(tTR_rP), 2): + ( + tTR_rP[reg_idx], + tTR_rP[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + (exp_last_column, exp_last_column), + (tState[reg_idx], tState[reg_idx + 1]), + (tTR_rP[reg_idx], tTR_rP[reg_idx + 1]), + ) + + # Store scaled P to tRS_rP + for reg_idx in range(cute.size(tTR_rP)): + tRS_rP[reg_idx] = tTR_rP[reg_idx].to(self.io_dtype) + + # Update old state + tState.store(tTR_rP.load()) + + # Store INTER2_P + inter2_p_coord = (None, None, None, inter2_p_producer_state.index) + cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Async arrive INTER1_ACC/INTER2_P buffer empty/full + inter1_acc_pipeline.consumer_release(inter1_acc_consumer_state) + # Last iteration consumer is PRE_INTER warp itself, not MMA_INTER warp + if inter2_p_producer_state.count < C: + inter2_p_pipeline.producer_commit(inter2_p_producer_state) + + # Advance B/Delta/INTER1_B/INTER1_ACC state + b_consumer_state.advance() + deltas_consumer_state.advance() + inter1_b_producer_state.advance() + inter1_acc_consumer_state.advance() + # Peek (try_wait) B/Delta/INTER1_B buffer full/full./empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_wr_inter1_b_empty_status = ( + self.conditional_producer_try_acquire( + inter1_b_producer_state, inter1_b_pipeline, C + ) + ) + + # Last iteration producer is PRE_INTER warp itself, not MMA_INTER warp + if inter2_p_producer_state.count < C: + # Advance INTER2_P producer state + inter2_p_producer_state.advance() + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Store last INTER2_P (State) from smem to gmem + # Wait for all previous stores to smem to be done + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier( + barrier_id=self.pre_inter_sync_bar_id, + number_of_threads=len(self.pre_inter_warp_id) * 32, + ) + + if local_warp_idx == 0: + # TMA store P + cute.copy( + tma_atom_p, + bSG_sP[(None, inter2_p_producer_state.index)], + bSG_gP, + ) + # Wait for TMA store done + tma_p_pipeline.producer_commit() + tma_p_pipeline.producer_acquire() + + cute.arch.barrier( + barrier_id=self.pre_inter_sync_bar_id, + number_of_threads=len(self.pre_inter_warp_id) * 32, + ) + tma_p_pipeline.producer_tail() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTER1_B/INTER2_P/TMA store P + inter1_b_pipeline.producer_tail(inter1_b_producer_state) + inter2_p_pipeline.producer_tail(inter2_p_producer_state) + # END of specialized pre-inter warp + + # Specialized Pre-Intra warp + elif ( + warp_idx == self.pre_intra_warp_id[0] + or warp_idx == self.pre_intra_warp_id[1] + or warp_idx == self.pre_intra_warp_id[2] + or warp_idx == self.pre_intra_warp_id[3] + ): + # Alloc regs in pre_inter warps + cute.arch.warpgroup_reg_alloc(self.num_regs_pre_intra_warps) + + # Make tmem fragment for INTRA1_ACC + # (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE) + tCtAccIntra1 = self.mma_partition_c( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + tmem_ptr_base + self.tmem_intra1_acc_offset, + self.intra1_acc_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTRA1_ACC_STAGE) + tIntra1 = tCtAccIntra1[((None, None), 0, 0, None)] + + # Make tiledCopy and partition tmem/register tensor for tensor memory load INTRA1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tiled_t2r_intra1, tTR_tQ, tTR_rQ = self.pre_intra_tmem_load_and_partition_q( + tIntra1, local_tidx + ) + + # Broadcast delta/delta_cumsum smem tensor from LxINPUT_STAGE to LxLxINPUT_STAGE + sDeltaA_Row = self.pre_intra_make_delta(smem_cumsum_delta, 0) + sDeltaA_Col = self.pre_intra_make_delta(smem_cumsum_delta, 1) + sDelta = self.pre_intra_make_delta(smem_delta, 0) + + # Make tiledCopy and partition smem/register tensor for smem memory load delta/delta_cumsum + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + s2r_atom_cumsum, + tQsDeltaA_Row, + tQrDeltaA_Row, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDeltaA_Row, (None, None, None, 0) + ) + ( + s2r_atom_cumsum, + tQsDeltaA_Col, + tQrDeltaA_Col, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDeltaA_Col, (None, None, None, 0) + ) + ( + s2r_atom_delta, + tQsDelta, + tQrDelta, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDelta, (None, None, None, 0) + ) + + # Make and partition coord tensor for delta_cumsum load + # (L, L) + coord_tensor = cute.make_identity_tensor( + cute.dice(self.tile_shape_mnk_intra1, (1, 1, None)) + ) + thr_t2r_intra1 = tiled_t2r_intra1.get_slice(local_tidx) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tCoord = thr_t2r_intra1.partition_D(coord_tensor) + + # Make tmem tensor for INTRA2_Q + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCrQ = self.mma_partition_a_tmem( + tiled_mma_intra2, + q_tmem_layout, + tmem_ptr_base + self.tmem_intra2_q_offset, + ) + + # Make tiledCopy and partition tmem/register tensor for tensor memory store INTRA2_Q + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE) + tiled_r2t_q, tRT_rQ, tRT_tQ = self.pre_intra_tmem_store_and_partition_q( + local_tidx, tCrQ + ) + + # Pipeline DELTA/INTRA1_ACC consumer state + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra1_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.intra1_acc_stages + ) + # Pipeline INTRA2_Q producer state + intra2_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + deltas_consumer_state.reset_count() + intra1_acc_consumer_state.reset_count() + intra2_q_producer_state.reset_count() + + # Peek (try_wait) DELTA/INTRA1_ACC buffer full + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait( + intra1_acc_consumer_state, intra1_acc_pipeline, C + ) + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for Delta/INTRA1_ACC buffer full + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + intra1_acc_pipeline.consumer_wait( + intra1_acc_consumer_state, peek_rd_intra1_acc_full_status + ) + + # Load Q from tmem + intra1_coord = (None, None, None, intra1_acc_consumer_state.index) + cute.copy(tiled_t2r_intra1, tTR_tQ[intra1_coord], tTR_rQ) + cute.arch.fence_view_async_tmem_load() + + # Load tQsDeltaA_Row/tQsDeltaA_Col/tQsDelta from smem + delta_coord = (None, None, None, deltas_consumer_state.index) + cute.copy( + s2r_atom_cumsum, tQsDeltaA_Row[delta_coord], tQrDeltaA_Row + ) + cute.copy( + s2r_atom_cumsum, tQsDeltaA_Col[delta_coord], tQrDeltaA_Col + ) + cute.copy(s2r_atom_delta, tQsDelta[delta_coord], tQrDelta) + + # SegSum + tRT_rQ = self.pre_intra_segsum( + tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ + ) + + # Wait for INTRA2_Q buffer empty + # Delay producer_acquire to right before data store + intra2_q_pipeline.producer_acquire(intra2_q_producer_state) + + # Store Q from reg to tmem + q_coord = (None, None, None, None, intra2_q_producer_state.index) + cute.copy(tiled_r2t_q, tRT_rQ, tRT_tQ[q_coord]) + + # Async arrive Delta/INTRA1_ACC buffer empty + intra1_acc_pipeline.consumer_release(intra1_acc_consumer_state) + deltas_pipeline.consumer_release(deltas_consumer_state) + + cute.arch.fence_view_async_tmem_store() + + # Async arrive INTRA2_Q buffer full + intra2_q_pipeline.producer_commit(intra2_q_producer_state) + + # Advance deltas/intra1_acc/intra2_q states + deltas_consumer_state.advance() + intra1_acc_consumer_state.advance() + intra2_q_producer_state.advance() + + # Peek (try_wait) Delta/INTRA1_ACC buffer full for chunk_idx = chunk_idx + 1 + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait( + intra1_acc_consumer_state, intra1_acc_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTRA2_Q + intra2_q_pipeline.producer_tail(intra2_q_producer_state) + # END of specialized pre-intra warp + + # Specialized Epilogue warp + else: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_epilogue_warps) + + # (L, D, INPUT_STAGE) + sDeltaA = self.epilog_make_delta(smem_cumsum_delta) + + # Make tmem tensor for INTRA2_ACC/INTER2_ACC + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCtAccIntra2 = self.mma_partition_c( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + tmem_ptr_base + self.tmem_intra2_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tIntra2 = tCtAccIntra2[((None, None), 0, 0, None)] + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCtAccInter2 = self.mma_partition_c( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + tmem_ptr_base + self.tmem_inter2_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tInter2 = tCtAccInter2[((None, None), 0, 0, None)] + + # Subtiling INTRA2_ACC/INTER2_ACC/Delta/Y + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INTERNAL_STAGE) + tIntra_epi = cute.flat_divide(tIntra2, epi_tile) + tInter_epi = cute.flat_divide(tInter2, epi_tile) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE) + sDeltaA_epi = cute.flat_divide(sDeltaA, epi_tile) + + # Make tiled copy and partition tmem/reg tensor w.r.t tensor memory load + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N) + ( + tiled_t2r_intra2, + tTR_tIntra, + tTR_rIntra, + ) = self.epilog_tmem_load_and_partition_acc(local_tidx, tIntra_epi, smem_y) + ( + tiled_t2r_inter2, + tTR_tInter2, + tTR_rInter, + ) = self.epilog_tmem_load_and_partition_acc(local_tidx, tInter_epi, smem_y) + + # Make tiled copy and partition smem/reg tensor w.r.t smem load Delta + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, EPI_M, EPI_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + s2r_atom_delta, + tTR_sDeltaA, + tTR_rDeltaA, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_inter2, local_tidx, sDeltaA_epi, (None, None, None, 0, 0, 0) + ) + + # Make tiled copy and Partition smem/register tensor w.r.t smem store Y + # ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N, OUTPUT_STAGE) + # ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N) + tiled_r2s_y, tRS_rY, tRS_sY = self.smem_store_and_partition_p_y( + local_tidx, smem_y, tiled_t2r_inter2 + ) + + tRS_rCompute = cute.make_fragment(tRS_rY.shape, self.acc_dtype) + + tiled_s2r_x = None + tSR_sX = None + tSR_rX = None + if cutlass.const_expr(self.has_d): + # Make TiledCopy/smem/register tensor for smem load X + # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES) + # (R2S_ATOM, R2S_M, R2S_N) + tiled_s2r_x, tSR_sX, tSR_rX = self.epilog_smem_load_and_partition_x( + tiled_t2r_inter2, local_tidx, smem_xt, epi_tile + ) + + tRS_sD = None + tRS_rD = None + s2r_atom_d = None + if cutlass.const_expr(self.d_has_hdim): + # (L, D, INPUT_STAGE) + sD = self.epilog_make_d(smem_d) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE) + tD_sepi = cute.flat_divide(sD, epi_tile) + + # Make tiled copy and partition smem/reg tensor w.r.t smem load D + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N) + s2r_atom_d, tRS_sD, tRS_rD = self.smem_load_and_partition_delta_d( + tiled_t2r_inter2, local_tidx, tD_sepi, (None, None, None, 0, 0, 0) + ) + + elif cutlass.const_expr(self.has_d): + tRS_rD = cutlass.Float32(0.0).to(self.io_dtype) + + # Partition global/shared tensor for TMA store Y + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B) + bSG_sY, bSG_gY_pre_slice = self.epilog_tma_partition_y( + tma_tensor_y, tma_atom_y, smem_y, epi_tile + ) + + # Make TMA store pipeline Y + tma_y_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.output_stages, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ), + ) + + # Make consumer pipeline states for Delta/INTRA2_ACC/INTER2_ACC/X/D buffer + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra2_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + inter2_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + x_consumer_state = None + if cutlass.const_expr(self.has_d): + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + d_consumer_state = None + if cutlass.const_expr(self.d_has_hdim): + d_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), EPI_M, EPI_N, C) + bSG_gY = bSG_gY_pre_slice[(None, None, None, 0, 0, None, eh_idx, b_idx)] + if cutlass.const_expr(self.has_d and not self.d_has_hdim): + tRS_rD = tma_tensor_d[0, eh_idx] + + # Reset count for pipeline state + deltas_consumer_state.reset_count() + intra2_acc_consumer_state.reset_count() + inter2_acc_consumer_state.reset_count() + if cutlass.const_expr(self.has_d): + x_consumer_state.reset_count() + if cutlass.const_expr(self.d_has_hdim): + d_consumer_state.reset_count() + + # Peek Delta/INTRA2_ACC/INTER2_ACC buffer status + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait( + intra2_acc_consumer_state, intra2_acc_pipeline, C + ) + peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait( + inter2_acc_consumer_state, inter2_acc_pipeline, C + ) + peek_rd_x_full_status = None + if cutlass.const_expr(self.has_d): + peek_rd_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.consumer_wait(d_consumer_state) + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for Delta/INTRA2_ACC/INTER2_ACC/X buffer full + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + intra2_acc_pipeline.consumer_wait( + intra2_acc_consumer_state, peek_rd_intra2_acc_full_status + ) + inter2_acc_pipeline.consumer_wait( + inter2_acc_consumer_state, peek_rd_inter2_acc_full_status + ) + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_wait( + x_consumer_state, peek_rd_x_full_status + ) + # Loop over EPI_M and EPI_N subtiles + for epi_n in range(cute.size(tTR_tIntra, mode=[4])): + for epi_m in range(cute.size(tTR_tIntra, mode=[3])): + epi_iter_cnt = ( + epi_n * cute.size(tTR_tIntra, mode=[3]) + epi_m + ) + epi_buffer_idx = epi_iter_cnt % self.output_stages + + # Load INTRA2_ACC/INTER2_ACC from tmem + subtile_coord = ( + None, + None, + None, + epi_m, + epi_n, + ) + intra2_coord = subtile_coord + ( + intra2_acc_consumer_state.index, + ) + cute.copy( + tiled_t2r_intra2, + tTR_tIntra[intra2_coord], + tTR_rIntra, + ) + inter2_coord = subtile_coord + ( + inter2_acc_consumer_state.index, + ) + cute.copy( + tiled_t2r_inter2, + tTR_tInter2[inter2_coord], + tTR_rInter, + ) + # Fence for T2R load + cute.arch.fence_view_async_tmem_load() + + # Load Delta from smem + delta_coord = subtile_coord + (deltas_consumer_state.index,) + cute.copy( + s2r_atom_delta, tTR_sDeltaA[delta_coord], tTR_rDeltaA + ) + + # Load X from smem + if cutlass.const_expr(self.has_d): + x_coord = subtile_coord + (x_consumer_state.index,) + cute.copy(tiled_s2r_x, tSR_sX[x_coord], tSR_rX) + + # Load D from smem + if cutlass.const_expr(self.d_has_hdim): + # Load vector D from smem (d_has_hdim = True) + d_coord = subtile_coord + (d_consumer_state.index,) + cute.copy(s2r_atom_d, tRS_sD[d_coord], tRS_rD) + + # Combine INTRA2_ACC/INTER2_ACC/Delta/X/D + for reg_idx in range(0, cute.size(tRS_rCompute), 2): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + (tTR_rInter[reg_idx], tTR_rInter[reg_idx + 1]), + ( + cute.arch.exp(tTR_rDeltaA[reg_idx].ir_value()), + cute.arch.exp( + tTR_rDeltaA[reg_idx + 1].ir_value() + ), + ), + (tTR_rIntra[reg_idx], tTR_rIntra[reg_idx + 1]), + ) + # Fuse Y += X * D + if cutlass.const_expr(self.d_has_hdim): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + ( + tRS_rD[reg_idx].to(self.acc_dtype), + tRS_rD[reg_idx + 1].to(self.acc_dtype), + ), + ( + tSR_rX[reg_idx].to(self.acc_dtype), + tSR_rX[reg_idx + 1].to(self.acc_dtype), + ), + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ), + ) + elif cutlass.const_expr(self.has_d): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + ( + tRS_rD.to(self.acc_dtype), + tRS_rD.to(self.acc_dtype), + ), + ( + tSR_rX[reg_idx].to(self.acc_dtype), + tSR_rX[reg_idx + 1].to(self.acc_dtype), + ), + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ), + ) + + tRS_rY.store(tRS_rCompute.load().to(self.io_dtype)) + + # Store Y to smem + cute.copy( + tiled_r2s_y, + tRS_rY, + tRS_sY[None, None, None, epi_buffer_idx], + ) + + # Fence for R2S store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # Sync before TMA store + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=len(self.epilog_warp_id) * 32, + ) + + # Async arrive Delta/INTRA2_ACC/INTER2_ACC buffer empty + if ( + epi_iter_cnt + == cute.size(tTR_tIntra, mode=[4]) + * cute.size(tTR_tIntra, mode=[3]) + - 1 + ): + deltas_pipeline.consumer_release(deltas_consumer_state) + intra2_acc_pipeline.consumer_release( + intra2_acc_consumer_state + ) + inter2_acc_pipeline.consumer_release( + inter2_acc_consumer_state + ) + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, + pipeline.PipelineOp.AsyncThread, + ) + + # TMA store Y to global memory + if local_warp_idx == 0: + cute.copy( + tma_atom_y, + bSG_sY[None, epi_buffer_idx], + bSG_gY[None, epi_m, epi_n, chunk_idx], + ) + + # Commit TMA store + tma_y_pipeline.producer_commit() + # Wait for TMA store + tma_y_pipeline.producer_acquire() + # Sync before smem store + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=len(self.epilog_warp_id) * 32, + ) + + # Advance deltas/intra2_acc/inter2_acc consumer states + deltas_consumer_state.advance() + intra2_acc_consumer_state.advance() + inter2_acc_consumer_state.advance() + + # Peek (try_wait) Delta/INTRA2_ACC/INTER2_ACC buffer full for chunk_idx = chunk_idx + 1 + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait( + intra2_acc_consumer_state, intra2_acc_pipeline, C + ) + peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait( + inter2_acc_consumer_state, inter2_acc_pipeline, C + ) + + if cutlass.const_expr(self.has_d): + # Advance x consumer states + x_consumer_state.advance() + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_rd_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.consumer_release(d_consumer_state) + d_consumer_state.advance() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Producer tail for TMA store Y + tma_y_pipeline.producer_tail() + + # Dealloc tmem buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.barrier( + barrier_id=self.tmem_dealloc_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + cute.arch.dealloc_tmem( + tmem_ptr_base, + self.num_tmem_cols_total, + is_two_cta=self.use_2cta_instrs, + ) + else: + cute.arch.barrier_arrive( + barrier_id=self.tmem_dealloc_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + + return + + @staticmethod + def _compute_stages(smem_capacity): + return 2, 2, 1, 2 # input, output, internal, intra1_acc + + @staticmethod + def _compute_grid(y, b, max_active_clusters): + B = cute.size(y, mode=[4]) + EH = cute.size(y, mode=[3]) + G = cute.size(b, mode=[3]) + NGROUP_RATIO = EH // G + num_blocks = B * EH + + tile_sched_params = Mamba2SSDTileSchedulerParams(num_blocks, EH, NGROUP_RATIO) + grid = Mamba2SSDTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return tile_sched_params, grid + + @staticmethod + def _plan_tmem_offsets( + tiled_mma_intra1, + tile_shape_mnk_intra1, + tiled_mma_intra2, + tile_shape_mnk_intra2, + tiled_mma_inter1, + tile_shape_mnk_inter1, + tiled_mma_inter2, + tile_shape_mnk_inter2, + acc_stages, + intra2_a_tmem_layout, + a_dtype, + internal_stages, + intra1_acc_stages, + ): + SM100_TMEM_CAPACITY_COLUMNS = 512 + BITS_PER_TMEM_COL = 32 + # (MMA, MMA_M, MMA_N) + acc_shape_intra1 = tiled_mma_intra1.partition_shape_C(tile_shape_mnk_intra1[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccIntra1_fake = tiled_mma_intra1.make_fragment_C( + cute.append(acc_shape_intra1, intra1_acc_stages) + ) + num_intra1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra1_fake) + assert tile_shape_mnk_intra1[1] * intra1_acc_stages == num_intra1_acc_cols + # (MMA, MMA_N, MMA_K, STAGE) + tCrQ_fake = tiled_mma_intra2.make_fragment_A(intra2_a_tmem_layout.outer.shape) + num_intra2_a_cols = tcgen05.find_tmem_tensor_col_offset(tCrQ_fake) + assert ( + tile_shape_mnk_intra2[2] + * internal_stages + * a_dtype.width + // BITS_PER_TMEM_COL + == num_intra2_a_cols + ) + # (MMA, MMA_M, MMA_N) + acc_shape_intra2 = tiled_mma_intra2.partition_shape_C(tile_shape_mnk_intra2[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccIntra2_fake = tiled_mma_intra2.make_fragment_C( + cute.append(acc_shape_intra2, acc_stages) + ) + num_intra2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra2_fake) + assert tile_shape_mnk_intra2[1] * acc_stages == num_intra2_acc_cols + + # (MMA, MMA_M, MMA_N) + acc_shape_inter1 = tiled_mma_inter1.partition_shape_C(tile_shape_mnk_inter1[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccInter1_fake = tiled_mma_inter1.make_fragment_C( + cute.append(acc_shape_inter1, acc_stages) + ) + num_inter1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter1_fake) + assert tile_shape_mnk_inter1[1] * acc_stages == num_inter1_acc_cols + + # (MMA, MMA_M, MMA_N) + acc_shape_inter2 = tiled_mma_inter2.partition_shape_C(tile_shape_mnk_inter2[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccInter2_fake = tiled_mma_inter2.make_fragment_C( + cute.append(acc_shape_inter2, acc_stages) + ) + num_inter2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter2_fake) + assert tile_shape_mnk_inter2[1] * acc_stages == num_inter2_acc_cols + + tmem_intra1_acc_offset = 0 + tmem_intra2_q_offset = tmem_intra1_acc_offset + num_intra1_acc_cols + tmem_intra2_acc_offset = tmem_intra2_q_offset + num_intra2_a_cols + tmem_inter1_acc_offset = tmem_intra2_acc_offset + num_intra2_acc_cols + tmem_inter2_acc_offset = tmem_inter1_acc_offset + num_inter1_acc_cols + num_tmem_cols_total_tmp = tmem_inter2_acc_offset + num_inter2_acc_cols + # Turn num_tmem_cols_total to the nearest power of 2 + num_tmem_cols_total = 1 + while num_tmem_cols_total < num_tmem_cols_total_tmp: + num_tmem_cols_total *= 2 + assert num_tmem_cols_total <= SM100_TMEM_CAPACITY_COLUMNS + + return ( + tmem_intra1_acc_offset, + tmem_intra2_q_offset, + tmem_intra2_acc_offset, + tmem_inter1_acc_offset, + tmem_inter2_acc_offset, + num_tmem_cols_total, + ) + + @staticmethod + def make_tiled_mmas( + io_dtype, + acc_dtype, + cta_group, + tile_shape_mnk_intra1, + tile_shape_mnk_intra2, + tile_shape_mnk_inter1, + tile_shape_mnk_inter2, + ): + tiled_mma_intra1 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("mn"), + tcgen05.OperandMajorMode("mn"), + acc_dtype, + cta_group, + tile_shape_mnk_intra1[:2], + tcgen05.OperandSource.SMEM, + ) + tiled_mma_intra2 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("k"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_intra2[:2], + tcgen05.OperandSource.TMEM, + ) + tiled_mma_inter1 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("k"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_inter1[:2], + tcgen05.OperandSource.SMEM, + ) + tiled_mma_inter2 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("mn"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_inter2[:2], + tcgen05.OperandSource.SMEM, + ) + return tiled_mma_intra1, tiled_mma_intra2, tiled_mma_inter1, tiled_mma_inter2 + + def make_and_init_x_pipeline(self, x_full_mbar_ptr): + x_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + if not self.has_d: + x_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_intra_warp_id, self.mma_inter_warp_id]), + ) + return pipeline.PipelineTmaUmma.create( + num_stages=self.input_stages, + producer_group=x_producer_group, + consumer_group=x_consumer_group, + tx_count=self.num_x_load_bytes, + barrier_storage=x_full_mbar_ptr, + ) + else: + x_consumer_group_umma = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_intra_warp_id, self.mma_inter_warp_id]), + ) + x_consumer_group_async = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineTmaMultiConsumersAsync.create( + num_stages=self.input_stages, + producer_group=x_producer_group, + consumer_group_umma=x_consumer_group_umma, + consumer_group_async=x_consumer_group_async, + tx_count=self.num_x_load_bytes, + barrier_storage=x_full_mbar_ptr, + ) + + def make_and_init_b_pipeline(self, b_full_mbar_ptr): + b_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_b_c_warp_id]) + ) + b_consumer_group_umma = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + b_consumer_group_async = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + return pipeline.PipelineTmaMultiConsumersAsync.create( + num_stages=self.input_stages, + producer_group=b_producer_group, + consumer_group_umma=b_consumer_group_umma, + consumer_group_async=b_consumer_group_async, + tx_count=self.num_b_load_bytes, + barrier_storage=b_full_mbar_ptr, + ) + + def make_and_init_c_pipeline(self, c_full_mbar_ptr): + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_b_c_warp_id]) + ) + c_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id, self.mma_inter_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + num_stages=self.input_stages, + producer_group=c_producer_group, + consumer_group=c_consumer_group, + tx_count=self.num_c_load_bytes, + barrier_storage=c_full_mbar_ptr, + ) + + def make_and_init_deltas_pipeline(self, deltas_full_mbar_ptr): + deltas_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + deltas_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len( + [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] + ), + len( + [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] + ), + ) + + return pipeline.PipelineTmaAsync.create( + num_stages=self.input_stages, + producer_group=deltas_producer_group, + consumer_group=deltas_consumer_group, + tx_count=self.num_delta_load_bytes + self.num_cumsum_delta_load_bytes, + barrier_storage=deltas_full_mbar_ptr, + ) + + def make_and_init_d_pipeline(self, d_full_mbar_ptr): + if not self.d_has_hdim: + return None + else: + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + d_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len(self.epilog_warp_id), + len(self.epilog_warp_id), + ) + + return pipeline.PipelineTmaAsync.create( + num_stages=self.input_stages, + producer_group=d_producer_group, + consumer_group=d_consumer_group, + tx_count=self.num_d_load_bytes, + barrier_storage=d_full_mbar_ptr, + ) + + def make_and_init_intra1_acc_pipeline(self, intra1_acc_full_mbar_ptr): + intra1_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + intra1_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.intra1_acc_stages, + producer_group=intra1_acc_producer_group, + consumer_group=intra1_acc_consumer_group, + barrier_storage=intra1_acc_full_mbar_ptr, + ) + + def make_and_init_intra2_q_pipeline(self, intra2_q_full_mbar_ptr): + intra2_q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + ) + intra2_q_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=intra2_q_producer_group, + consumer_group=intra2_q_consumer_group, + barrier_storage=intra2_q_full_mbar_ptr, + ) + + def make_and_init_intra2_acc_pipeline(self, intra2_acc_full_mbar_ptr): + intra2_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + intra2_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=intra2_acc_producer_group, + consumer_group=intra2_acc_consumer_group, + barrier_storage=intra2_acc_full_mbar_ptr, + ) + + def make_and_init_inter1_b_pipeline(self, inter1_b_full_mbar_ptr): + inter1_b_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + inter1_b_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=inter1_b_producer_group, + consumer_group=inter1_b_consumer_group, + barrier_storage=inter1_b_full_mbar_ptr, + ) + + def make_and_init_inter1_acc_pipeline(self, inter1_acc_full_mbar_ptr): + inter1_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + inter1_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=inter1_acc_producer_group, + consumer_group=inter1_acc_consumer_group, + barrier_storage=inter1_acc_full_mbar_ptr, + ) + + def make_and_init_inter2_p_pipeline(self, inter2_p_full_mbar_ptr): + inter2_p_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + inter2_p_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=inter2_p_producer_group, + consumer_group=inter2_p_consumer_group, + barrier_storage=inter2_p_full_mbar_ptr, + ) + + def make_and_init_inter2_acc_pipeline(self, inter2_acc_full_mbar_ptr): + inter2_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + inter2_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=inter2_acc_producer_group, + consumer_group=inter2_acc_consumer_group, + barrier_storage=inter2_acc_full_mbar_ptr, + ) + + def tma_partition_for_mma_b_operand( + self, + tma_atom_x, + tma_tensor_x, + smem_x, + tiled_mma_intra2, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ): + # Local_tile partition global tensors + # (D, L, 1, 1, C, EH, B) + gX = cute.local_tile( + tma_tensor_x, + self.tile_shape_mnk_intra2[1:], + (None, None, None, None, None), + ) + # Partition global tensor with regard to TiledMMA + thr_mma_intra2 = tiled_mma_intra2.get_slice(mma_tile_coord_v) + # (MMA, MMA_N, MMA_K, 1, 1, C, EH, B) + tCgX = thr_mma_intra2.partition_B(gX) + + # Partition global/shared tensor for X + x_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, EH, B) + tXsX, tXgX_pre_slice = cpasync.tma_partition( + tma_atom_x, + block_in_cluster_coord_vmnk[2], + x_cta_layout, + cute.group_modes(smem_x, 0, 3), + cute.group_modes(tCgX, 0, 3), + ) + return tXsX, tXgX_pre_slice + + def tma_partition_for_mma_a_operand( + self, + tma_atom_c, + tma_tensor_c, + smem_c, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ): + # Local_tile partition global tensors + # (L, N, 1, 1, C, G, B) + gC = cute.local_tile( + tma_tensor_c, + cute.slice_(self.tile_shape_mnk_intra1, (None, 0, None)), + (None, None, None, None, None), + ) + # Partition global tensor with regard to TiledMMA + thr_mma_intra1 = tiled_mma_intra1.get_slice(mma_tile_coord_v) + # (MMA, MMA_M/N, MMA_K, 1, 1, C, G, B) + tCgC = thr_mma_intra1.partition_A(gC) + + # Partition global/shared tensor for TMA C + c_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tCsC, tCgC_pre_slice = cpasync.tma_partition( + tma_atom_c, + block_in_cluster_coord_vmnk[1], + c_cta_layout, + cute.group_modes(smem_c, 0, 3), + cute.group_modes(tCgC, 0, 3), + ) + return tCsC, tCgC_pre_slice + + def tma_partition_with_shape( + self, tma_atom_delta, tma_tensor_delta, smem_delta, shape + ): + # Local_tile partition global tensors + # (L, 1, C, EH, B) + gDelta = cute.local_tile( + tma_tensor_delta, + shape, + (None,) * cute.rank(tma_tensor_delta), + ) + # Partition global/shared tensor for DELTA + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + tDeltasDelta, tDeltagDelta_pre_slice = cpasync.tma_partition( + tma_atom_delta, + 0, + cute.make_layout(1), + cute.group_modes(smem_delta, 0, cute.rank(shape)), + cute.group_modes(gDelta, 0, cute.rank(shape)), + ) + + return tDeltasDelta, tDeltagDelta_pre_slice + + def mma_partition_ss( + self, + tiled_mma, + tile_shape_mnk, + smem_a, + smem_b, + tmem_acc_ptr, + acc_stages, + ): + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + tCrA = tiled_mma.make_fragment_A(smem_a) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + tCrB = tiled_mma.make_fragment_B(smem_b) + # (MMA, MMA_M, MMA_N, ACC_STAGE) + tCtAcc = self.mma_partition_c( + tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages + ) + return tCrA, tCrB, tCtAcc + + def mma_partition_ts( + self, + tiled_mma, + tile_shape_mnk, + a_tmem_layout, + smem_b, + tmem_a_ptr, + tmem_acc_ptr, + acc_stages, + ): + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCrA = self.mma_partition_a_tmem(tiled_mma, a_tmem_layout, tmem_a_ptr) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + tCrB = tiled_mma.make_fragment_B(smem_b) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAcc = self.mma_partition_c( + tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages + ) + return tCrA, tCrB, tCtAcc + + def mma_partition_a_tmem(self, tiled_mma, a_tmem_layout, tmem_a_ptr): + tCrA_fake = tiled_mma.make_fragment_A(a_tmem_layout.outer.shape) + tCrA = cute.make_tensor( + cute.recast_ptr( + tmem_a_ptr, + dtype=tCrA_fake.element_type, + ), + tCrA_fake.layout, + ) + return tCrA + + def mma_partition_c(self, tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages): + acc_shape = tiled_mma.partition_shape_C(tile_shape_mnk[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, acc_stages)) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAcc = cute.make_tensor(tmem_acc_ptr, tCtAcc_fake.layout) + return tCtAcc + + @cute.jit + def exec_mma( + self, + tiled_mma, + tCtAcc, + tCrA, + tCrB, + acc_producer_state, + a_consumer_state, + b_consumer_state, + ): + for kphase_idx in cutlass.range(cute.size(tCrB, mode=[2]), unroll_full=True): + # set accu = 1 + tiled_mma.set( + tcgen05.Field.ACCUMULATE, + cutlass.Boolean(kphase_idx != 0), + ) + cute.gemm( + tiled_mma, + tCtAcc[None, None, None, acc_producer_state.index], + tCrA[None, None, kphase_idx, a_consumer_state.index], + tCrB[None, None, kphase_idx, b_consumer_state.index], + tCtAcc[None, None, None, acc_producer_state.index], + ) + return tiled_mma + + @cute.jit + def conditional_consumer_try_wait(self, b_consumer_state, b_pipeline, C): + peek_b_full_status = cutlass.Boolean(1) + if b_consumer_state.count < C: + peek_b_full_status = b_pipeline.consumer_try_wait(b_consumer_state) + return peek_b_full_status + + @cute.jit + def conditional_producer_try_acquire( + self, intra1_acc_producer_state, intra1_acc_pipeline, C + ): + peek_wr_intra1_acc_empty_status = cutlass.Boolean(1) + if intra1_acc_producer_state.count < C: + peek_wr_intra1_acc_empty_status = intra1_acc_pipeline.producer_try_acquire( + intra1_acc_producer_state + ) + return peek_wr_intra1_acc_empty_status + + def pre_intra_tmem_load_and_partition_q(self, tIntra1, local_tidx): + copy_atom_t2r_intra1 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(16), tcgen05.Pack.NONE), + self.acc_dtype, + ) + # (L, L) + fake_sQ = cute.make_tensor( + cute.make_ptr(self.io_dtype, 0, cute.AddressSpace.smem), + cute.dice(self.tile_shape_mnk_intra1, (1, 1, None)), + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_intra1, tIntra1, (None, None, 0), local_tidx, fake_sQ + ) + + def pre_intra_make_delta(self, smem_delta, extend_on_row_or_col): + smem_iterator = smem_delta.iterator + delta_linear_smem_layout = smem_delta.layout + # extend L linear layout to LxL + extend_layout = cute.make_layout(delta_linear_smem_layout.shape[0], stride=0) + if extend_on_row_or_col == 0: + # (L, L, INPUT_STAGE):(0, 1, L) + sDelta = cute.make_tensor( + smem_iterator, + cute.prepend( + delta_linear_smem_layout, + extend_layout, + ), + ) + else: + # (L, L, INPUT_STAGE):(1, 0, L) + sDelta = cute.make_tensor( + smem_iterator, + cute.append( + cute.append( + cute.get(delta_linear_smem_layout, mode=[0]), + extend_layout, + ), + cute.get(delta_linear_smem_layout, mode=[1]), + ), + ) + return sDelta + + def pre_intra_tmem_store_and_partition_q(self, local_tidx, tCrQ): + dtype = tCrQ.element_type + # Make tiledCopy for tensor memory store INTRA2_Q + copy_atom_r2t_q = cute.make_copy_atom( + tcgen05.St16x128bOp(tcgen05.Repetition(16), tcgen05.Unpack.NONE), + dtype, + ) + tiled_r2t_q = tcgen05.make_tmem_copy(copy_atom_r2t_q, tCrQ) + thr_r2t_q = tiled_r2t_q.get_slice(local_tidx) + + # Partition tmem/register tensor for tensor memory store INTRA2_Q + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...) + tRT_rQ = cute.make_fragment( + cute.slice_(thr_r2t_q.partition_S(tCrQ).shape, (None, None, None, None, 0)), + dtype, + ) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE) + tRT_tQ = thr_r2t_q.partition_D(tCrQ) + + return tiled_r2t_q, tRT_rQ, tRT_tQ + + @cute.jit + def pre_intra_segsum( + self, tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ + ): + # Make tmp acc type fragments + tCrDeltaA_Row = cute.make_fragment(tQrDeltaA_Row.shape, self.acc_dtype) + tCrDeltaA_Col = cute.make_fragment(tQrDeltaA_Col.shape, self.acc_dtype) + tCrDelta = cute.make_fragment(tQrDelta.shape, self.acc_dtype) + tCompute = cute.make_fragment(tRT_rQ.shape, self.acc_dtype) + + # Combine tTR_rQ/tCrDeltaA_Row/tCrDeltaA_Col/tCrDelta + tCrDeltaA_Row.store(tQrDeltaA_Row.load().to(self.acc_dtype)) + tCrDeltaA_Col.store(tQrDeltaA_Col.load().to(self.acc_dtype)) + tCrDelta.store(tQrDelta.load().to(self.acc_dtype)) + + # SegSum + # fadd2 + fsel + fmul2/mufu + fmul2 + for subtile_idx in range(0, cute.size(tTR_rQ), 2): + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.add_packed_f32x2( + (tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]), + (-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]), + ) + for subtile_idx in range(cute.size(tTR_rQ)): + m, n = tCoord[subtile_idx] + if m < n: + tCompute[subtile_idx] = cutlass.Float32(-float("inf")) + for subtile_idx in range(0, cute.size(tTR_rQ), 2): + # TODO: use math.exp directly + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.mul_packed_f32x2( + cute.arch.exp_packed_f32x2( + (tCompute[subtile_idx], tCompute[subtile_idx + 1]) + ), + (tCrDelta[subtile_idx], tCrDelta[subtile_idx + 1]), + ) + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[subtile_idx], tCompute[subtile_idx + 1]), + (tTR_rQ[subtile_idx], tTR_rQ[subtile_idx + 1]), + ) + + tRT_rQ.store(tCompute.load().to(self.io_dtype)) + return tRT_rQ + + def pre_inter_smem_load_and_partition_b(self, local_tidx, smem_bt): + dtype = smem_bt.element_type + copy_atom_s2r_b = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + dtype, + num_bits_per_copy=128, + ) + num_elements_per_thread = 128 // dtype.width + num_threads_per_row = self.tile_shape_mnk_inter1[2] // num_elements_per_thread + num_threads_per_col = 128 // num_threads_per_row + thread_layout = cute.make_layout( + (num_threads_per_col, num_threads_per_row), + stride=(num_threads_per_row, 1), + ) + val_layout = cute.make_layout((1, num_elements_per_thread)) + tiled_s2r_b = cute.make_tiled_copy_tv( + copy_atom_s2r_b, + thread_layout, + val_layout, + ) + thr_s2r_b = tiled_s2r_b.get_slice(local_tidx) + + # Partition shared tensor for smem load Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + tBsB_s2r = thr_s2r_b.partition_S(smem_bt) + + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrB_s2r = cute.make_fragment( + cute.slice_(tBsB_s2r.shape, (None, None, None, 0)), + dtype, + ) + return tiled_s2r_b, tBsB_s2r, tBrB_s2r + + def pre_inter_smem_store_and_partition_b( + self, local_tidx, smem_bt_internal, tiled_s2r_b, tBrB_s2r + ): + dtype = smem_bt_internal.element_type + # Make tiledCopy from register to smem store Bt + copy_atom_r2s_b = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + dtype, + num_bits_per_copy=128, + ) + tiled_r2s_b = cute.make_tiled_copy( + copy_atom_r2s_b, + layout_tv=tiled_s2r_b.layout_tv_tiled, + tiler_mn=tiled_s2r_b.tiler_mn, + ) + thr_r2s_b = tiled_r2s_b.get_slice(local_tidx) + + # Partition shared tensor for smem store Bt + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tBsB_r2s = thr_r2s_b.partition_D(smem_bt_internal) + + # Make register fragments for smem load/store Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrB_r2s = thr_r2s_b.retile(tBrB_s2r) + return tiled_r2s_b, tBrB_r2s, tBsB_r2s + + def smem_load_and_partition_delta_d( + self, tiled_s2r_b, local_tidx, smem_delta, smem_tile_coord + ): + dtype = smem_delta.element_type + s2r_atom_delta = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), dtype) + + thr_s2r_b = tiled_s2r_b.get_slice(local_tidx) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + tBsDelta_s2r = thr_s2r_b.partition_D(smem_delta) + + # Make register fragments for smem load/store of Delta/DeltaA + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrDelta_s2r = cute.make_fragment(tBsDelta_s2r[smem_tile_coord].shape, dtype) + return s2r_atom_delta, tBsDelta_s2r, tBrDelta_s2r + + def pre_inter_tmem_load_and_partition_p(self, local_tidx, tInter1, smem_pt): + copy_atom_t2r_inter1 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(8), tcgen05.Pack.NONE), + self.acc_dtype, + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_inter1, + tInter1, + (None, None, 0), + local_tidx, + smem_pt[None, None, 0], + ) + + def make_tmem_load_and_partition( + self, copy_atom_t2r, tmem_tensor, tmem_tile_coord, local_tidx, smem_tensor + ): + dtype = tmem_tensor.element_type + tiled_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tmem_tensor[tmem_tile_coord]) + thr_t2r = tiled_t2r.get_slice(local_tidx) + # Partition tmem/shared tensor for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tTR_t = thr_t2r.partition_S(tmem_tensor) + tTR_s = thr_t2r.partition_D(smem_tensor) + # Make register fragments for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tTR_r = cute.make_fragment( + tTR_s.shape, + dtype, + ) + return tiled_t2r, tTR_t, tTR_r + + def smem_store_and_partition_p_y(self, local_tidx, smem_pt, tiled_t2r_inter1): + dtype = smem_pt.element_type + copy_atom_r2s_p = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=True, num_matrices=4), + dtype, + ) + tiled_r2s_p = cute.make_tiled_copy_D(copy_atom_r2s_p, tiled_t2r_inter1) + thr_r2s_p = tiled_r2s_p.get_slice(local_tidx) + # Partition smem/register tensor for smem store INTER2_P + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tRS_sP = thr_r2s_p.partition_D(smem_pt) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + tRS_rP = cute.make_fragment( + cute.slice_(tRS_sP.shape, (None, None, None, 0)), self.io_dtype + ) + return tiled_r2s_p, tRS_rP, tRS_sP + + def pre_inter_make_delta(self, smem_delta, smem_bt_layout): + # Broadcast Delta/DeltaA to Bt shape on M dimension + # before: (128,(64,2),2):(64,(1,8192),16384) + # after : (128,(64,2),2):(0,(1,64),128) + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + sDeltaA = cute.make_tensor( + smem_delta.iterator, + cute.make_layout( + smem_bt_layout.shape, + stride=( + 0, + (1, cute.get(smem_bt_layout.shape, mode=[1, 0])), + smem_delta.layout.stride[1], + ), + ), + ) + return sDeltaA + + def pre_inter_scale_bt_with_delta( + self, tBrB_s2r, tBrDelta_s2r, tBrDeltaA_s2r, last_column + ): + tCompute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) + tBrB_Compute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) + tBrDelta_Compute = cute.make_fragment(tBrDelta_s2r.shape, self.acc_dtype) + tBrDeltaA_Compute = cute.make_fragment(tBrDeltaA_s2r.shape, self.acc_dtype) + + tBrB_Compute.store(tBrB_s2r.load().to(self.acc_dtype)) + tBrDelta_Compute.store(tBrDelta_s2r.load().to(self.acc_dtype)) + tBrDeltaA_Compute.store(tBrDeltaA_s2r.load().to(self.acc_dtype)) + + for reg_idx in range(0, cute.size(tBrB_Compute), 2): + tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2( + ( + cute.arch.exp( + (last_column - tBrDeltaA_Compute[reg_idx]).ir_value() + ), + cute.arch.exp( + (last_column - tBrDeltaA_Compute[reg_idx + 1]).ir_value() + ), + ), + (tBrDelta_Compute[reg_idx], tBrDelta_Compute[reg_idx + 1]), + ) + tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2( + (tCompute[reg_idx], tCompute[reg_idx + 1]), + (tBrB_Compute[reg_idx], tBrB_Compute[reg_idx + 1]), + ) + return tCompute + + def epilog_make_delta(self, smem_cumsum_delta): + # Broadcast cumsum delta from LxINPUT_STAGE to LxDxINPUT_STAGE + sDeltaA = cute.make_tensor( + smem_cumsum_delta.iterator, + cute.make_layout( + (*self.tile_shape_mnk_inter2[:2], self.input_stages), + stride=(1, 0, smem_cumsum_delta.layout.shape[0]), + ), + ) + return sDeltaA + + def epilog_make_d(self, smem_d): + # Broadcast d from DxINPUT_STAGE to LxDxINPUT_STAGE + sD = cute.make_tensor( + smem_d.iterator, + cute.make_layout( + (*self.tile_shape_mnk_inter2[:2], self.input_stages), + stride=(0, 1, smem_d.layout.shape[0]), + ), + ) + return sD + + def epilog_tma_partition_y(self, tma_tensor_y, tma_atom_y, smem_y, epi_tile): + # Local_tile partition global tensors + # (L, D, 1, 1, C, EH, B) + gY = cute.local_tile( + tma_tensor_y, + cute.slice_(self.tile_shape_mnk_inter2, (None, None, 0)), + (None, None, None, None, None), + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, 1, 1, C, EH, B) + gY_epi = cute.flat_divide(gY, epi_tile) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B) + bSG_sY, bSG_gY_pre_slice = cpasync.tma_partition( + tma_atom_y, + 0, + cute.make_layout(1), + cute.group_modes(smem_y, 0, 2), + cute.group_modes(gY_epi, 0, 2), + ) + return bSG_sY, bSG_gY_pre_slice + + def epilog_smem_load_and_partition_x( + self, tiled_t2r_inter2_intra2, local_tidx, smem_xt, epi_tile + ): + dtype = smem_xt.element_type + copy_atom_s2r_x = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + dtype, + ) + tiled_s2r_x = cute.make_tiled_copy_D(copy_atom_s2r_x, tiled_t2r_inter2_intra2) + thr_s2r_x = tiled_s2r_x.get_slice(local_tidx) + # Partition smem/register tensor for smem store INTER2_P + # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES) + tSR_sX = thr_s2r_x.partition_S(cute.flat_divide(smem_xt, epi_tile)) + # (R2S_ATOM, R2S_M, R2S_N) + tSR_rX = cute.make_fragment( + cute.slice_(tSR_sX.shape, (None, None, None, 0, 0, 0)), dtype + ) + return tiled_s2r_x, tSR_sX, tSR_rX + + def epilog_tmem_load_and_partition_acc(self, local_tidx, tIntra, smem_y): + copy_atom_t2r_inter2_intra2 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(4), tcgen05.Pack.NONE), + self.acc_dtype, + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_inter2_intra2, + tIntra, + (None, None, 0, 0, 0), + local_tidx, + smem_y[None, None, 0], + ) + + +def run_ssd( + gbehcdln: Tuple[int, int, int, int, int, int, int, int], + io_dtype: Type[cutlass.Numeric], + cumsum_delta_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + has_d: bool, + d_has_hdim: bool, + tolerance: float, + print_rtol_stats: bool, + ref_lower_precision: bool, +): + print(f"Running B100 Mamba2 SSD with:") + print(f"GBEHCDLN: {gbehcdln}") + print( + f"Input/Output dtype: {io_dtype}, Intermediate delta dtype: {cumsum_delta_dtype}, Acc dtype: {acc_dtype}" + ) + print( + f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}" + ) + print(f"Tolerance: {tolerance}") + + # Unpack parameters + G, B, E, H, C, D, L, N = gbehcdln + EH = E * H + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Match same seed in ssd_reference.py for reference check + torch.manual_seed(42) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + shape, permute_order, dtype, dt_or_a=0, dynamic_modes=None, ref_tensor=None + ): + # Build fp32 reference torch tensor + if ref_tensor is None: + ref_tensor = ( + torch.empty(*shape, dtype=torch.float32) + # .random_(-1, 1) + .normal_(0, 0.5) + # .uniform_(-1,1) + .permute(permute_order) + ) + if dt_or_a == 1: # dt: + ref_tensor = F.softplus(ref_tensor - 4) + elif dt_or_a == 2: # A: + ref_tensor = -torch.exp(ref_tensor) + + # Build torch_dtype torch tensor + torch_dtype = cutlass_torch.dtype(dtype) + + dst_tensor = ref_tensor.to(torch_dtype).cuda() + cute_tensor = from_dlpack(dst_tensor, assumed_align=16) + for mode in dynamic_modes: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=mode, stride_order=dst_tensor.dim_order() + ) + + return ref_tensor, cute_tensor, dst_tensor + + # INPUT tensors + # x: (D, L, C, EH, B):(C*L, 1, L, D*C*L, EH*D*C*L) + x_ref, x_tensor, x_torch = create_and_permute_tensor( + [B, EH, D, C, L], [2, 4, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + # delta/delta_a/cumsum_delta: (L, C, EH, B):(1, L, C*L, EH*C*L) + delta_ref, delta_tensor, delta_torch = create_and_permute_tensor( + [B, EH, C, L], [3, 2, 1, 0], io_dtype, dt_or_a=1, dynamic_modes=[1, 2, 3] + ) + # a: (EH):(1) + a_ref, a_tensor, a_torch = create_and_permute_tensor( + [EH], [0], io_dtype, dt_or_a=2, dynamic_modes=[0] + ) + + if has_d: + # d: (D, EH):(1, D) or (1, EH):(0, 1) + d_ref, d_tensor, d_torch = create_and_permute_tensor( + [EH, D if d_has_hdim else 1], [1, 0], io_dtype, dynamic_modes=[1] + ) + else: + d_ref = None + d_tensor = None + + # b/c: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + + # OUTPUT tensors + # y: (L, D, C, EH, B):(1, C*L, L, D*C*L, EH*D*C*L) + y_ref, y_tensor, y_torch = create_and_permute_tensor( + [B, EH, D, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + # fstate: (D, N, EH, B):(N, 1, D*N, EH*D*N) + fstate_ref, fstate_tensor, fstate_torch = create_and_permute_tensor( + [B, EH, D, N], [2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3] + ) + + # Call pytorch reference on cpu + if not ref_lower_precision: + ssd_reference_fp32_all( + x_ref, + a_ref, + delta_ref, + b_ref, + c_ref, + y_ref, + fstate_ref, + d_ref, + has_d, + d_has_hdim, + ) + else: + ssd_reference_lowprecision_intermediates( + x_ref, + a_ref, + delta_ref, + b_ref, + c_ref, + y_ref, + fstate_ref, + cutlass_torch.dtype(io_dtype), + d_ref, + has_d, + d_has_hdim, + ) + + # Compute cumsum with pytorch on cpu + delta_a_ref = delta_ref * a_ref.view(1, 1, -1, 1) + cumsum_delta_ref = torch.empty([B, EH, C, L], dtype=torch.float32).permute( + [3, 2, 1, 0] + ) + cumsum_delta_ref.copy_(torch.cumsum(delta_a_ref, dim=0).permute([0, 1, 2, 3])) + # Copy cumsum_delta_ref to cumsum_delta_tensor + ( + cumsum_delta_ref, + cumsum_delta_tensor, + cumsum_delta_torch, + ) = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + cumsum_delta_dtype, + ref_tensor=cumsum_delta_ref, + dynamic_modes=[1, 2, 3], + ) + + # Call fused ssd kernel + ssd = SSDKernel( + io_dtype, + cumsum_delta_dtype, + acc_dtype, + L, + D, + N, + has_d, + d_has_hdim, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(1) + + stream = cutlass.cuda.default_stream() + # Compile ssd kernel + compiled_ssd = cute.compile( + ssd, + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + max_active_clusters, + stream, + ) + + # Launch compiled ssd kernel + compiled_ssd( + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + stream, + ) + + # Reference check + if print_rtol_stats: + print("\nY's Relative diffs:") + analyze_relative_diffs(y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))) + print("\nFstate's Relative diffs:") + analyze_relative_diffs( + fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype)) + ) + torch.testing.assert_close( + y_torch.cpu(), + y_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-02, + ) + torch.testing.assert_close( + fstate_torch.cpu(), + fstate_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> List[int]: + try: + return [int(x.strip()) for x in s.split(",")] + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." + ) + + parser.add_argument( + "--gbehcdln", + type=parse_comma_separated_ints, + default=[2, 4, 2, 40, 32, 64, 128, 128], + # default=[2, 3, 2, 2, 8, 64, 128, 128], + # default=[1, 2, 1, 4, 8, 64, 128, 128], + help="gbehcdln dimensions (comma-separated)", + ) + parser.add_argument("--io_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument( + "--cumsum_delta_dtype", type=cutlass.dtype, default=cutlass.Float32 + ) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--fuse_scale_d", + type=str, + choices=["none", "scalar", "vector"], + default="vector", + help="Fuse scale type: none (no Y+=X*D fusion), scalar (Y+=X*D fusion with D.shape=1xEH), or vector (Y+=X*D fusion with D.shape=DxEH)", + ) + parser.add_argument( + "--ref_lower_precision", + type=bool, + default=True, + help="Use lower precision for reference check", + ) + parser.add_argument( + "--tolerance", type=float, default=5e-02, help="Tolerance for validation" + ) + parser.add_argument( + "--print_rtol_stats", type=bool, default=True, help="Print rtol stats" + ) + + args = parser.parse_args() + + if len(args.gbehcdln) != 8: + parser.error("--gbehcdln must contain exactly 8 values") + + has_d = args.fuse_scale_d != "none" + d_has_hdim = args.fuse_scale_d == "vector" + + run_ssd( + args.gbehcdln, + args.io_dtype, + args.cumsum_delta_dtype, + args.acc_dtype, + has_d, + d_has_hdim, + args.tolerance, + args.print_rtol_stats, + args.ref_lower_precision, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py new file mode 100644 index 00000000..ced5c6f2 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py @@ -0,0 +1,397 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch +import torch.nn.functional as F + + +def ssd_reference_fp32_all(x, a, delta, B, C, Y_out, Fstate_out, D, has_d, d_has_hdim): + """ + Rearrange tensor dimensions from cuda layout to reference layout, then directly call TriDao's ssd implementation + Arguments: + X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L) + A/delta: (L, C, H, B):(1, L, C*L, H*C*L) + a: (H):(1) + B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + D: (1, H):(0, 1) or (D, H):(1, D) + has_d: bool + d_has_hdim: bool + Return: + Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L) + Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N) + """ + assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype + + A = delta * a.view(1, 1, -1, 1) + X = x * delta.unsqueeze(0) + + # Rearrange to match cutlass layout to tridao's layout + block_len = A.shape[0] + initial_states = None + # A: l c h b-> b c l h + A = A.permute(3, 1, 0, 2) + # X: p l c h b -> b c l h p + X = X.permute(4, 2, 1, 3, 0) + # B: l n c g b -> b c l g n + B = B.permute(4, 2, 0, 3, 1) + # C: l n c g b -> b c l g n + C = C.permute(4, 2, 0, 3, 1) + # X/A/B/C: b c l ... -> b (c l) ... + X, A, B, C = [x.reshape(x.shape[0], -1, *x.shape[3:]) for x in (X, A, B, C)] + + # Ngroup (g to h) mapping + B_val, CL_val, G_val, N_val = B.shape + H_val = X.shape[2] + ngroup_ratio = H_val // G_val + # B/C: (B, CL, H, N) + h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio + B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + + ################################################################### + # Call reference implementation from Tri Dao ssd_minimal_discrete + Y, final_state = ssd_minimal_discrete_fp32_all( + X, A, B, C, block_len, initial_states + ) + ################################################################### + + if has_d: + D_val = Y.shape[3] + if not d_has_hdim: + D = D.expand(D_val, -1) + Y = Y + torch.einsum("bchp,ph->bchp", X, D) + + # Rearrange to match tridao's layout to cutlass layout + # Y: b (c l) h p -> b c l h p + Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3]) + # Y: b c l h p -> l p c h b + Y = Y.permute(2, 4, 1, 3, 0) + # Fstate_out: b h p n -> p n h b + Fstate_out.copy_(final_state.permute(2, 3, 1, 0)) + Y_out.copy_(Y) + return + + +def ssd_reference_lowprecision_intermediates( + x, a, delta, B, C, Y_out, Fstate_out, intermediate_dtype, D, has_d, d_has_hdim +): + """ + Rearrange tensor dimensions from cuda layout to reference layout, then call a reduced intermediate dtype version of ssd implementation + Arguments: + X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L) + A/delta: (L, C, H, B):(1, L, C*L, H*C*L) + a: (H):(1) + B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + intermediate_dtype: input and intermediate data type + D: (1, H):(0, 1) or (D, H):(1, D) + has_d: bool + d_has_hdim: bool + Return: + Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L) + Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N) + """ + assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype + + A = delta * a.view(1, 1, -1, 1) + + # Rearrange to match cutlass layout to tridao's layout + block_len = A.shape[0] + initial_states = None + # A: l c h b-> b c l h + A = A.permute(3, 1, 0, 2) + # delta: l c h b-> b c l h + delta = delta.permute(3, 1, 0, 2) + # x: p l c h b -> b c l h p + x = x.permute(4, 2, 1, 3, 0) + # B: l n c g b -> b c l g n + B = B.permute(4, 2, 0, 3, 1) + # C: l n c g b -> b c l g n + C = C.permute(4, 2, 0, 3, 1) + # x/A/delta/B/C: b c l ... -> b (c l) ... + x, A, delta, B, C = [ + tensor.reshape(tensor.shape[0], -1, *tensor.shape[3:]) + for tensor in (x, A, delta, B, C) + ] + + # Ngroup (g to h) mapping + B_val, CL_val, G_val, N_val = B.shape + H_val = x.shape[2] + ngroup_ratio = H_val // G_val + # B/C: (B, CL, H, N) + h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio + B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + + # Type convert input tensors to input dtype (same as intermediate dtype) + x = x.to(intermediate_dtype).to(torch.float32) + A = A.to(intermediate_dtype).to(torch.float32) + delta = delta.to(intermediate_dtype).to(torch.float32) + B = B.to(intermediate_dtype).to(torch.float32) + C = C.to(intermediate_dtype).to(torch.float32) + + ######################################################################### + # Call reference implementation ssd_minimal_discrete_bf16_intermediates + Y, final_state = ssd_minimal_discrete_lowprecision_intermediates( + x, A, delta, B, C, block_len, intermediate_dtype, initial_states + ) + ######################################################################### + + if has_d: + D = D.to(intermediate_dtype).to(torch.float32) + D_val = Y.shape[3] + if not d_has_hdim: + D = D.expand(D_val, -1) + Y = Y + torch.einsum("bchp,ph->bchp", x, D) + + # Type convert output tensors to output dtype (same as intermediate dtype) + Y = Y.to(intermediate_dtype).to(torch.float32) + final_state = final_state.to(intermediate_dtype).to(torch.float32) + + # Rearrange to match tridao's layout to cutlass layout + # Y: b (c l) h p -> b c l h p + Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3]) + # Y: b c l h p -> l p c h b + Y = Y.permute(2, 4, 1, 3, 0) + # Fstate_out: b h p n -> p n h b + Fstate_out.copy_(final_state.permute(2, 3, 1, 0)) + Y_out.copy_(Y) + return + + +def analyze_relative_diffs(actual, expected): + """ + Print statistics of relative differences between actual and expected tensors + """ + # Calculate relative differences + abs_diff = (actual - expected).abs() + rel_diff = abs_diff / (torch.maximum(expected.abs(), actual.abs()) + 0.00001) + + total_elements = rel_diff.numel() + + # Handle special cases first + nan_mask = torch.isnan(rel_diff) + inf_mask = torch.isinf(rel_diff) + nan_count = nan_mask.sum().item() + inf_count = inf_mask.sum().item() + + # Find position and value of maximum relative difference + max_rel_diff = ( + rel_diff[~nan_mask & ~inf_mask].max() + if (~nan_mask & ~inf_mask).any() + else float("nan") + ) + max_rel_diff_pos = ( + rel_diff[~nan_mask & ~inf_mask].argmax() + if (~nan_mask & ~inf_mask).any() + else -1 + ) + + # Print max relative difference info + print(f"Maximum relative difference:") + print(f"Position: {max_rel_diff_pos}") + print(f"Value: {max_rel_diff:.6e}") + print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}") + print(f"Expected value: {expected.flatten()[max_rel_diff_pos]}") + print(f"NaN values: {nan_count} ({100.0 * nan_count / total_elements:.2f}%)") + print(f"Inf values: {inf_count} ({100.0 * inf_count / total_elements:.2f}%)\n") + + # Check different rtol thresholds + rtol_levels = [1e-5, 1e-4, 1e-3, 1e-2, 5e-02, 1e-01] + + for i, rtol in enumerate(rtol_levels): + if i == 0: + mask = rel_diff <= rtol + else: + mask = (rel_diff <= rtol) & (rel_diff > rtol_levels[i - 1]) + + count = mask.sum().item() + percentage = (count / total_elements) * 100 + + if i == 0: + print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)") + else: + print( + f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)" + ) + + # Print elements exceeding the largest rtol + mask = rel_diff > rtol_levels[-1] + count = mask.sum().item() + percentage = (count / total_elements) * 100 + print(f"Elements with rtol > {rtol_levels[-1]:.0e}: {count} ({percentage:.2f}%)\n") + + +def segsum(x): + """ + More stable segment sum calculation. + x: b h c l + """ + T = x.size(-1) + # x: b h c l -> b h c l l + x = x.unsqueeze(-1).expand(*x.shape, T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete_fp32_all(X, A, B, C, block_len, initial_states=None): + """ + This is same with https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py + (all accumulation and intermediate results in fp32) + + Arguments: + X: (batch(B), length(C*L), n_heads(H), d_head(D)) + A: (batch(B), length(C*L), n_heads(H)) + B: (batch(B), length(C*L), n_heads(H), d_state(N)) + C: (batch(B), length(C*L), n_heads(H), d_state(N)) + Return: + Y: (batch(B), length(C*L), n_heads(H), d_head(D)) + final_state: (B, H, D, N) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + # X/A/B/C:b (c l) ... -> b c l ... + X, A, B, C = [ + x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, B, C) + ] + + # A: b c l h -> b h c l + A = A.permute(0, 3, 1, 2) + # A_cumsum: (B, H, C, L) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + segsum_A = segsum(A) + L = torch.exp(segsum_A) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Y: b c l h p -> b (c l) h p + Y = (Y_diag + Y_off).reshape(Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4]) + return Y, final_state + + +def ssd_minimal_discrete_lowprecision_intermediates( + X, A, delta, B, C, block_len, intermediate_dtype, initial_states=None +): + """ + This is adjusted from ssd_minimal_discrete_fp32_all, with exceptions: + 1. accumulation in fp32 but intermediates Q/b_tmem/P are in intermediate_dtype + 2. delta is not pre-multiplied with X, delta was applied to generate Q/b_tmem to match GPU implementation + + Arguments: + X: (batch(B), length(C*L), n_heads(H), d_head(D)) + A: (batch(B), length(C*L), n_heads(H)) + delta: (batch(B), length(C*L), n_heads(H)) + B: (batch(B), length(C*L), n_heads(H), d_state(N)) + C: (batch(B), length(C*L), n_heads(H), d_state(N)) + Return: + Y: (batch(B), length(C*L), n_heads(H), d_head(D)) + final_state: (B, H, D, N) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + # X/A/delta/B/C: b (c l) ... -> b c l ... + X, A, delta, B, C = [ + x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, delta, B, C) + ] + + # A: b c l h -> b h c l + A = A.permute(0, 3, 1, 2) + # delta: b c l h -> b h c l + delta = delta.permute(0, 3, 1, 2) + # A_cumsum: (B, H, C, L) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + segsum_A = segsum(A) + L = torch.exp(segsum_A) + intra_acc_0 = torch.einsum("bclhn,bcshn->bclhs", C, B) + Q = torch.einsum("bclhs,bhcls,bhcs->bclhs", intra_acc_0, L, delta) + Y_diag = torch.einsum( + "bclhs,bcshp->bclhp", Q.to(intermediate_dtype).to(torch.float32), X + ) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + b_tmem = torch.einsum("bclhn,bhcl,bhcl->bclhn", B, decay_states, delta) + states = torch.einsum( + "bclhn,bclhp->bchpn", b_tmem.to(intermediate_dtype).to(torch.float32), X + ) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + final_state = final_state + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off_tmp = torch.einsum( + "bclhn,bchpn->bclhp", C, states.to(intermediate_dtype).to(torch.float32) + ) + Y_off = torch.einsum("bclhp,bhcl->bclhp", Y_off_tmp, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Y: b c l h p -> b (c l) h p + Y = (Y_diag + Y_off).reshape( + Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4] + ) # b (c l) h p + return Y, final_state diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py new file mode 100644 index 00000000..544b4772 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Tuple + +from cutlass.cutlass_dsl import ( + Boolean, + Integer, + Int32, + min, + extract_mlir_values, + new_from_mlir_values, + dsl_user_op, +) +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass.utils import WorkTileInfo + + +class Mamba2SSDTileSchedulerParams: + def __init__( + self, + problem_shape_ntiles: int, + eh: int, + ngroup_ratio: int, + *, + loc=None, + ip=None, + ): + self.problem_shape_ntiles = problem_shape_ntiles + self.eh = eh + self.ngroup_ratio = ngroup_ratio + self._loc = loc + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_ntiles, self.eh, self.ngroup_ratio]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.problem_shape_ntiles, self.eh, self.ngroup_ratio], self._values_pos + ): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return Mamba2SSDTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + @dsl_user_op + def get_grid_shape( + self, max_active_clusters: Int32, *, loc=None, ip=None + ) -> Tuple[Integer, Integer, Integer]: + return (min(self.problem_shape_ntiles, max_active_clusters), 1, 1) + + +class Mamba2SSDTileScheduler: + def __init__( + self, + params: Mamba2SSDTileSchedulerParams, + num_persistent_ctas: Int32, + current_work_linear_idx: Int32, + num_tiles_executed: Int32, + ): + self.params = params + self.num_persistent_ctas = num_persistent_ctas + self._current_work_linear_idx = current_work_linear_idx + self._num_tiles_executed = num_tiles_executed + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.num_persistent_ctas) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self._num_tiles_executed)) + return values + + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "Mamba2SSDTileScheduler": + assert len(values) == 3 + new_num_persistent_ctas = new_from_mlir_values( + self.num_persistent_ctas, [values[0]] + ) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[1]] + ) + new_num_tiles_executed = new_from_mlir_values( + self._num_tiles_executed, [values[2]] + ) + return Mamba2SSDTileScheduler( + self.params, + new_num_persistent_ctas, + new_current_work_linear_idx, + new_num_tiles_executed, + ) + + # called by host + @dsl_user_op + @staticmethod + def create( + params: Mamba2SSDTileSchedulerParams, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ): + params = params + + # Calculate the number of persistent clusters by dividing the total grid size + # by the number of CTAs per cluster + num_persistent_ctas = Int32(cute.size(grid_dim, loc=loc, ip=ip)) + + bidx, bidy, bidz = block_idx + + # Initialize workload index equals to the cluster index in the grid + current_work_linear_idx = Int32(bidx) + + # Initialize number of tiles executed to zero + num_tiles_executed = Int32(0) + return Mamba2SSDTileScheduler( + params, + num_persistent_ctas, + current_work_linear_idx, + num_tiles_executed, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Mamba2SSDTileSchedulerParams, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Integer, Integer, Integer]: + return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) + + # private method + def _get_current_work_for_linear_idx( + self, current_work_linear_idx: Int32, *, loc=None, ip=None + ) -> WorkTileInfo: + is_valid = current_work_linear_idx < cute.size( + self.params.problem_shape_ntiles, loc=loc, ip=ip + ) + + eh_idx = current_work_linear_idx % self.params.eh + b_idx = current_work_linear_idx // self.params.eh + g_idx = eh_idx // self.params.ngroup_ratio + # cur_tile_coord is (b_idx, eh_idx, g_idx) + cur_tile_coord = tuple(Int32(x) for x in (b_idx, eh_idx, g_idx)) + + return WorkTileInfo(cur_tile_coord, is_valid) + + @dsl_user_op + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + return self._get_current_work_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + @dsl_user_op + def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + self._current_work_linear_idx += Int32(advance_count) * Int32( + self.num_persistent_ctas + ) + self._num_tiles_executed += Int32(1) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed diff --git a/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt b/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt index 6c02b198..f2cf2336 100644 --- a/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt +++ b/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt @@ -30,7 +30,7 @@ cmake_minimum_required(VERSION 3.15) project(tensor) # Find Python -find_package(Python COMPONENTS Interpreter Development REQUIRED) +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) # Get Python site-packages directory using Python execute_process( diff --git a/examples/python/CuTeDSL/hopper/dense_gemm.py b/examples/python/CuTeDSL/hopper/dense_gemm.py index dc9a0604..d9d9783d 100644 --- a/examples/python/CuTeDSL/hopper/dense_gemm.py +++ b/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -36,6 +36,7 @@ import torch import cutlass import cutlass.cute as cute import cutlass.utils as utils +import cutlass.pipeline as pipeline import cutlass.torch as cutlass_torch from cutlass.cute.runtime import from_dlpack import cutlass.utils.hopper_helpers as sm90_utils @@ -579,20 +580,25 @@ class HopperWgmmaGemmKernel: mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() # Threads/warps participating in this pipeline - mainloop_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) - # Set the consumer arrive count to the number of mcast size - consumer_arrive_cnt = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - mainloop_pipeline_consumer_group = utils.CooperativeGroup( - utils.Agent.Thread, consumer_arrive_cnt + mainloop_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread + ) + # Each warp will constribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + num_warps = self.threads_per_cta // 32 + consumer_arrive_cnt = mcast_size * num_warps + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt ) - mainloop_pipeline = utils.PipelineTmaAsync.create( + cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape)) + mainloop_pipeline = pipeline.PipelineTmaAsync.create( barrier_storage=mainloop_pipeline_array_ptr, num_stages=self.ab_stage, producer_group=mainloop_pipeline_producer_group, consumer_group=mainloop_pipeline_consumer_group, tx_count=tma_copy_bytes, - cta_layout_vmnk=cta_layout_mnk, + cta_layout_vmnk=cta_layout_vmnk, ) # Cluster arrive after barrier init @@ -616,11 +622,11 @@ class HopperWgmmaGemmKernel: # /////////////////////////////////////////////////////////////////////////////// # Local_tile partition global tensors # /////////////////////////////////////////////////////////////////////////////// - # (bM, bK, loopK) + # (bM, bK, RestK) gA_mkl = cute.local_tile( mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1) ) - # (bN, bK, loopK) + # (bN, bK, RestK) gB_nkl = cute.local_tile( mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1) ) @@ -696,14 +702,14 @@ class HopperWgmmaGemmKernel: k_tile_cnt = cute.size(gA_mkl, mode=[2]) prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.ab_stage, k_tile_cnt), 0) - mainloop_producer_state = utils.make_pipeline_state( - utils.PipelineUserType.Producer, self.ab_stage + mainloop_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage ) if warp_idx == 0: # ///////////////////////////////////////////////////////////////////////////// # Prefetch TMA load # ///////////////////////////////////////////////////////////////////////////// - for prefetch_idx in cutlass.range_dynamic(prefetch_k_tile_cnt, unroll=1): + for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): # ///////////////////////////////////////////////////////////////////////////// # Wait for A/B buffers to be empty before loading into them # Also sets the transaction barrier for the A/B buffers @@ -748,11 +754,11 @@ class HopperWgmmaGemmKernel: # ///////////////////////////////////////////////////////////////////////////// k_pipe_mmas = 1 - mainloop_consumer_read_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.ab_stage + mainloop_consumer_read_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage ) - mainloop_consumer_release_state = utils.make_pipeline_state( - utils.PipelineUserType.Consumer, self.ab_stage + mainloop_consumer_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage ) peek_ab_full_status = cutlass.Boolean(1) @@ -763,14 +769,14 @@ class HopperWgmmaGemmKernel: tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) num_k_blocks = cute.size(tCrA, mode=[2]) - for k_tile in cutlass.range_dynamic(k_pipe_mmas, unroll=1): + for k_tile in range(k_pipe_mmas): # Wait for A/B buffer to be ready mainloop_pipeline.consumer_wait( mainloop_consumer_read_state, peek_ab_full_status ) cute.nvgpu.warpgroup.fence() - for k_block_idx in range(num_k_blocks): + for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True): k_block_coord = ( None, None, @@ -800,7 +806,7 @@ class HopperWgmmaGemmKernel: # ///////////////////////////////////////////////////////////////////////////// # MAINLOOP # ///////////////////////////////////////////////////////////////////////////// - for k_tile in cutlass.range_dynamic(k_pipe_mmas, k_tile_cnt, 1, unroll=1): + for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1): # ///////////////////////////////////////////////////////////////////////////// # Wait for TMA copies to complete # ///////////////////////////////////////////////////////////////////////////// @@ -811,7 +817,7 @@ class HopperWgmmaGemmKernel: # WGMMA # ///////////////////////////////////////////////////////////////////////////// cute.nvgpu.warpgroup.fence() - for k_block_idx in range(num_k_blocks): + for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True): k_block_coord = ( None, None, @@ -949,7 +955,7 @@ class HopperWgmmaGemmKernel: epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1]) epi_tile_shape = tcgc_for_tma_partition.shape[1] - for epi_idx in cutlass.range_dynamic(epi_tile_num, unroll=epi_tile_num): + for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num): # Copy from accumulators to D registers for epi_v in range(size_tRS_rD): tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] @@ -1213,7 +1219,7 @@ class HopperWgmmaGemmKernel: c_cta_v_layout = cute.composition( cute.make_identity_layout(tensor_c.shape), epi_tile ) - tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tma_tile_atom( + tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom( cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), tensor_c, epi_smem_layout, @@ -1250,7 +1256,7 @@ class HopperWgmmaGemmKernel: ) smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) - tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom( + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( op, tensor, smem_layout, diff --git a/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb b/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb index 035776aa..3a3f9ed7 100644 --- a/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb +++ b/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb @@ -297,7 +297,7 @@ " assert depth <= 1, f\"Depth of coalesced layout should be <= 1, got {depth}\"\n", "\n", " print(\">>> 3. Checking layout functionality remains the same after the coalesce operation:\")\n", - " for i in range(original_size):\n", + " for i in cutlass.range_constexpr(original_size):\n", " original_value = layout(i)\n", " coalesced_value = result(i)\n", " print(f\"Index {i}: original {original_value}, coalesced {coalesced_value}\")\n", diff --git a/examples/python/CuTeDSL/notebooks/data_types.ipynb b/examples/python/CuTeDSL/notebooks/data_types.ipynb index e618885d..dc305fff 100644 --- a/examples/python/CuTeDSL/notebooks/data_types.ipynb +++ b/examples/python/CuTeDSL/notebooks/data_types.ipynb @@ -60,48 +60,7 @@ "@cute.jit\n", "def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n", " ...\n", - "```\n", - "To differentiate between compile-time and runtime values, CuTe DSL introduces primitive types that \n", - "represent dynamic values in JIT-compiled code.\n", - "\n", - "CuTe DSL provides a comprehensive set of primitive numeric types for representing dynamic values at \n", - "runtime. These types are formally defined within the CuTe DSL typing system:\n", - "\n", - "### Integer Types\n", - "- `Int8` - 8-bit signed integer\n", - "- `Int16` - 16-bit signed integer \n", - "- `Int32` - 32-bit signed integer\n", - "- `Int64` - 64-bit signed integer\n", - "- `Int128` - 128-bit signed integer\n", - "- `Uint8` - 8-bit unsigned integer\n", - "- `Uint16` - 16-bit unsigned integer\n", - "- `Uint32` - 32-bit unsigned integer\n", - "- `Uint64` - 64-bit unsigned integer\n", - "- `Uint128` - 128-bit unsigned integer\n", - "\n", - "### Floating Point Types\n", - "- `Float16` - 16-bit floating point\n", - "- `Float32` - 32-bit floating point \n", - "- `Float64` - 64-bit floating point\n", - "- `BFloat16` - Brain Floating Point format (16-bit)\n", - "- `TFloat32` - Tensor Float32 format (reduced precision format used in tensor operations)\n", - "- `Float8E4M3` - 8-bit floating point with 4-bit exponent and 3-bit mantissa\n", - "- `Float8E5M2` - 8-bit floating point with 5-bit exponent and 2-bit mantissa\n", - "\n", - "These specialized types are designed to represent dynamic values in CuTe DSL code that will be \n", - "evaluated at runtime, in contrast to Python's built-in numeric types which are evaluated during \n", - "compilation.\n", - "\n", - "### Example usage:\n", - "\n", - "```python\n", - "x = cutlass.Int32(5) # Creates a 32-bit integer\n", - "y = cutlass.Float32(3.14) # Creates a 32-bit float\n", - "\n", - "@cute.jit\n", - "def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n", - " ...\n", - "```" + "```\n" ] }, { diff --git a/examples/python/CuTeDSL/notebooks/tensorssa.ipynb b/examples/python/CuTeDSL/notebooks/tensorssa.ipynb index 8d83e02e..3e812681 100644 --- a/examples/python/CuTeDSL/notebooks/tensorssa.ipynb +++ b/examples/python/CuTeDSL/notebooks/tensorssa.ipynb @@ -120,7 +120,7 @@ " src_vec = src.load()\n", " dst_vec = src_vec[indices]\n", " print(f\"{src_vec} -> {dst_vec}\")\n", - " if isinstance(dst_vec, cute.TensorSSA):\n", + " if cutlass.const_expr(isinstance(dst_vec, cute.TensorSSA)):\n", " dst.store(dst_vec)\n", " cute.print_tensor(dst)\n", " else:\n", diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index d05b170f..1da05ad4 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -260,13 +260,10 @@ copy(AutoVectorizingCopyWithAssumedAlignment const&, { // If more than one element vectorizes to 8bits or more, then recast and copy using VecType = uint_bit_t; - // Preserve volatility - using SrcVecType = conditional_t, VecType const volatile, VecType const>; - using DstVecType = conditional_t, VecType volatile, VecType >; // Recast - Tensor src_v = recast(src); - Tensor dst_v = recast(dst); + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); return copy_if(constant_fn{}, src_v, dst_v); } else { return copy_if(constant_fn{}, src, dst); diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 5c455cc3..0331d992 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -325,21 +325,6 @@ struct TiledCopy : Copy_Atom return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutS_MN() - { - // (thr_idx,val_idx) -> (M,N) - auto layoutS_TV = get_layoutS_TV(); - // (M,K) -> (thr_idx,val_idx) - auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{})); - - // athrid = (v,m,k) -> thr_idx - auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); - - return cute::make_tuple(layoutS_MK, thrID_S); - } - CUTE_HOST_DEVICE constexpr static auto get_layoutD_TV() @@ -350,21 +335,6 @@ struct TiledCopy : Copy_Atom return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutD_MN() - { - // (thr_idx,val_idx) -> (M,N) - auto layoutD_TV = get_layoutD_TV(); - // (M,K) -> (thr_idx,val_idx) - auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{})); - - // athrid = (v,m,k) -> thr_idx - auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); - - return cute::make_tuple(layoutD_MK, thrID_D); - } - template ::value)> CUTE_HOST_DEVICE static @@ -680,101 +650,6 @@ print(ThrCopy const& thr_copy) print(TiledCopy{}); } -// TiledCopy to LaTeX TikZ -template -CUTE_HOST_DEVICE -auto -print_latex(TiledCopy const& copy, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); - auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); - - print_latex_copy(layoutS_MN, thrID_S, - layoutD_MN, thrID_D); -} - -// MNK Copy Layout to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutD const& D, ThrIDD const& TD, // (m,n) -> (tid,vid) and tid -> thr_idx - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); - - assert(size<0>(S) == size<0>(D)); - assert(size<1>(S) == size<1>(D)); - - // Commented prints - printf("%% LayoutS: "); print(S); printf("\n"); - printf("%% ThrIDS : "); print(TS); printf("\n"); - printf("%% LayoutD: "); print(D); printf("\n"); - printf("%% ThrIDD : "); print(TD); printf("\n\n"); - - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // S starting at 0,0 - for (int i = 0; i < size<0>(S); ++i) { - for (int j = 0; j < size<1>(S); ++j) { - int thrid = S(i,j) % size(TS); - int val_idx = S(i,j) / size(TS); - int thr_idx = TS(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - i, j, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, 0, int(size<0>(S)), int(size<1>(S))); - // S Labels - for (int i = 0, j = -1; i < size<0>(S); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int i = -1, j = 0; j < size<1>(S); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } - - // D starting at 0,size<1>(S)+3 - for (int i = 0; i < size<0>(D); ++i) { - for (int j = 0; j < size<1>(D); ++j) { - int thrid = D(i,j) % size(TD); - int val_idx = D(i,j) / size(TD); - int thr_idx = TD(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - i, j + size<1>(S) + 3, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, int(size<1>(S)+3), int(size<0>(D)), int(size<1>(D)+size<1>(S)+3)); - // D Labels - for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); - } - for (int i = -1, j = 0; j < size<1>(D); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - } // end namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index e2f9bdfc..6367fc0c 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -180,10 +180,10 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeB is a view (of the current tensor), forward the whole static_assert(is_same::value_type>::value - + || (sizeof_bits_v::value_type> == 8 && (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) - + , "Expecting ValTypeB type"); return make_tensor(static_cast(btensor)); } else { @@ -394,55 +394,22 @@ struct TiledMMA : MMA_Atom return size(permutation_mnk()); } - CUTE_HOST_DEVICE constexpr - auto - get_layoutC_MN() const - { - // (M,N) -> (M,N) - auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); - // (cthrid,val) -> (M,N) - auto layoutC_TV = thrfrg_C(ref_C); - // (M,N) -> (cthrid,frg) - auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); - - // cthrid = (v,m,n) -> thr_idx - auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{}); - - return cute::make_tuple(layoutC_MN, thrID_C); - } - CUTE_HOST_DEVICE constexpr auto get_layoutC_TV() const { // (M,N) -> (M,N) auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); - // (cthrid,val) -> (M,N) - auto layoutC_TV = thrfrg_C(ref_C); // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); // (thr_idx,val) -> (M,N) - return layoutC_TV.compose(thridx_2_thrid, _); + return thrfrg_C(ref_C).compose(thridx_2_thrid, _); } - CUTE_HOST_DEVICE constexpr - auto - get_layoutA_MK() const - { - // (M,K) -> (M,K) - auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); - // (athrid,val) -> (M,K) - auto layoutA_TV = thrfrg_A(ref_A); - // (M,K) -> (athrid,frg) - auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); - - // athrid = (v,m,k) -> thr_idx - auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_); - - return cute::make_tuple(layoutA_MK, thrID_A); - } CUTE_HOST_DEVICE constexpr auto @@ -458,29 +425,14 @@ struct TiledMMA : MMA_Atom _)); // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); // (thr_idx,val) -> (M,K) return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _); } - CUTE_HOST_DEVICE constexpr - auto - get_layoutB_NK() const - { - // (N,K) -> (N,K) - auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); - // (bthrid,val) -> (N,K) - auto layoutB_TV = thrfrg_B(ref_B); - // (N,K) -> (bthrid,frg) - auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); - - // bthrid = (v,n,k) -> thr_idx - auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_); - - return cute::make_tuple(layoutB_NK, thrID_B); - } - CUTE_HOST_DEVICE constexpr auto get_layoutB_TV() const @@ -495,7 +447,9 @@ struct TiledMMA : MMA_Atom _)); // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); // (thr_idx,val) -> (N,K) return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _); @@ -733,376 +687,6 @@ print(ThrMMA const& thr_mma) print(static_cast(thr_mma)); } -// MMA Atom to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(MMA_Atom const& mma_atom, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - print_latex(make_tiled_mma(mma_atom)); -} - -// TiledMMA to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(TiledMMA const& mma, - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - auto layout_and_thrid_C = mma.get_layoutC_MN(); - auto layoutC_MN = get<0>(layout_and_thrid_C); - auto thrID_C = get<1>(layout_and_thrid_C); - - auto layout_and_thrid_A = mma.get_layoutA_MK(); - auto layoutA_MK = get<0>(layout_and_thrid_A); - auto thrID_A = get<1>(layout_and_thrid_A); - - auto layout_and_thrid_B = mma.get_layoutB_NK(); - auto layoutB_NK = get<0>(layout_and_thrid_B); - auto thrID_B = get<1>(layout_and_thrid_B); - - print_latex_mma(layoutC_MN, thrID_C, - layoutA_MK, thrID_A, - layoutB_NK, thrID_B); -} - -// MNK MMA Layout to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB, // (n,k) -> (tid,vid) and tid -> thr_idx - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - assert(size<0>(A) == size<0>(C)); - assert(size<0>(B) == size<1>(C)); - assert(size<1>(A) == size<1>(B)); - - // Commented prints - printf("%% LayoutC: "); print(C); printf("\n"); - printf("%% ThrIDC : "); print(TC); printf("\n"); - printf("%% LayoutA: "); print(A); printf("\n"); - printf("%% ThrIDA : "); print(TA); printf("\n"); - printf("%% LayoutB: "); print(B); printf("\n"); - printf("%% ThrIDB : "); print(TB); printf("\n\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // C starting at 0,0 - for (int m = 0; m < size<0>(C); ++m) { - for (int n = 0; n < size<1>(C); ++n) { - int thrid = C(m,n) % size(TC); - int val_idx = C(m,n) / size(TC); - int thr_idx = TC(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - m, n, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, 0, int(size<0>(C)), int(size<1>(C))); - - // A starting at 0,-size<1>(A)-1 - for (int m = 0; m < size<0>(A); ++m) { - for (int k = 0; k < size<1>(A); ++k) { - int thrid = A(m,k) % size(TA); - int val_idx = A(m,k) / size(TA); - int thr_idx = TA(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - m, k-1-size<1>(A), - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - 0, int(-size<1>(A)-1), int(size<0>(A)), -1); - // A labels - for (int m = 0, k = -1; m < size<0>(A); ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); - } - for (int m = -1, k = 0; k < size<1>(A); ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); - } - - // B starting at -size<1>(B)-1,0 - for (int n = 0; n < size<0>(B); ++n) { - for (int k = 0; k < size<1>(B); ++k) { - int thrid = B(n,k) % size(TB); - int val_idx = B(n,k) / size(TB); - int thr_idx = TB(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - k-1-size<1>(B), n, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", - int(-size<1>(B)-1), 0, -1, int(size<0>(B))); - // B labels - for (int n = 0, k = -1; n < size<0>(B); ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); - } - for (int n = -1, k = 0; k < size<1>(B); ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -// MNK MMA Layout to console printer -template -CUTE_HOST_DEVICE -void -print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - assert(size<0>(A) == size<0>(C)); - assert(size<0>(B) == size<1>(C)); - assert(size<1>(A) == size<1>(B)); - - int a_width = size<1>(A) * 6 + 4; - - // Print out B (white-shifted) k-by-n - for (int k = 0; k < size<1>(B); ++k) { - // Header - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n"); - // Values - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); - printf("|\n"); - } - // Footer - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n\n"); - - // Print out A m-by-k and C m-by-n - for (int m = 0; m < size<0>(A); ++m) { - // Header - for (int k = 0; k < size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); - // Values - for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); - printf("| "); - for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); - printf("|\n"); - } - // Footer - for (int k = 0; k < size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); -} - -// MNK MMA Layout to SVG -- 8-value color coded by thread -template -CUTE_HOST_DEVICE -void -print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -{ - char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175", - "255,175,175", "210,210,255", "210,255,210", - "255,255,210", "255,210,210"}; - - const int cell_width = 20; - const int cell_height = 20; - - const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width; - const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height; - - // header - printf("\n", - page_width, page_height); - - // C - int c_base_x = (size<1>(A) + 2) * cell_width; - int c_base_y = (size<1>(B) + 2) * cell_height; - for (int m = 0; m < cute::size<0>(C); ++m) { - for (int n = 0; n < cute::size<1>(C); ++n) { - - int thrid = C(m, n) % size(TC); - int val_idx = C(m, n) / size(TC); - int thr_idx = TC(thrid); - - int x = n * cell_width + c_base_x; - int y = m * cell_height + c_base_y; - - int thr_x = x + cell_width / 2; - int thr_y = y + cell_height / 4; - int val_x = x + cell_width / 2; - int val_y = y + cell_height * 3 / 4; - - printf("\n", - x, y, cell_width, cell_height, color_map[thr_idx % 8]); - - printf("T%d\n", - thr_x, thr_y, thr_idx); - printf("V%d\n", - val_x, val_y, val_idx); - } - } - - // A - int a_base_x = cell_width; - int a_base_y = (size<1>(B) + 2) * cell_height; - for (int m = 0; m < size<0>(A); ++m) { - for (int k = 0; k < size<1>(A); ++k) { - int thrid = A(m, k) % size(TA); - int val_idx = A(m, k) / size(TA); - int thr_idx = TA(thrid); - - int x = k * cell_width + a_base_x; - int y = m * cell_height + a_base_y; - - int thr_x = x + cell_width / 2; - int thr_y = y + cell_height / 4; - int val_x = x + cell_width / 2; - int val_y = y + cell_height * 3 / 4; - - printf("\n", - x, y, cell_width, cell_height, color_map[thr_idx % 8]); - printf("T%d\n", - thr_x, thr_y, thr_idx); - printf("V%d\n", - val_x, val_y, val_idx); - } - } - - // B - int b_base_x = (size<1>(A) + 2) * cell_width; - int b_base_y = cell_height; - for (int n = 0; n < size<0>(B); ++n) { - for (int k = 0; k < size<1>(B); ++k) { - int thrid = B(n, k) % size(TB); - int val_idx = B(n, k) / size(TB); - int thr_idx = TB(thrid); - - int x = n * cell_width + b_base_x; - int y = k * cell_height + b_base_y; - - int thr_x = x + cell_width / 2; - int thr_y = y + cell_height / 4; - int val_x = x + cell_width / 2; - int val_y = y + cell_height * 3 / 4; - - printf("\n", - x, y, cell_width, cell_height, color_map[thr_idx % 8]); - printf("T%d\n", - thr_x, thr_y, thr_idx); - printf("V%d\n", - val_x, val_y, val_idx); - } - } - - // A labels - for (int m = 0; m < size<0>(A); ++m) { - int x = cell_width / 2; - int y = m * cell_height + cell_height / 2 + a_base_y; - printf("%d\n", - x, y, m); - } - for (int k = 0; k < size<1>(A); ++k) { - int x = cell_width + k * cell_width + cell_width / 2; - int y = -cell_height / 2 + a_base_y; - printf("%d\n", - x, y, k); - } - - // B labels - for (int n = 0; n < size<0>(B); ++n) { - int x = b_base_x + cell_width * n + cell_width / 2; - int y = cell_height / 2; - printf("%d\n", - x, y, n); - } - for (int k = 0; k < size<1>(B); ++k) { - int x = b_base_x - cell_width / 2; - int y = cell_height * (k + 1) + cell_height / 2; - printf("%d\n", - x, y, k); - } - - // footer - printf("\n"); -} - -template -CUTE_HOST_DEVICE -void -print_svg(MMA_Atom const &mma_atom) { - print_svg(make_tiled_mma(mma_atom)); -} - -template -CUTE_HOST_DEVICE -void -print_svg(TiledMMA const &mma) { - auto layout_and_thrid_C = mma.get_layoutC_MN(); - auto layoutC_MN = get<0>(layout_and_thrid_C); - auto thrID_C = get<1>(layout_and_thrid_C); - - auto layout_and_thrid_A = mma.get_layoutA_MK(); - auto layoutA_MK = get<0>(layout_and_thrid_A); - auto thrID_A = get<1>(layout_and_thrid_A); - - auto layout_and_thrid_B = mma.get_layoutB_NK(); - auto layoutB_NK = get<0>(layout_and_thrid_B); - auto thrID_B = get<1>(layout_and_thrid_B); - - print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B); -} - } // namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1114,7 +698,7 @@ print_svg(TiledMMA const &mma) { #include #include #include -#include +#include #include #include diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index 820dc103..2e69c7bb 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -3844,4 +3844,39 @@ struct MMA_Traits +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_2,_1,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_1,_2,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + } // end namespace cute diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 3f02a41d..3844d187 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -834,7 +834,7 @@ coalesce_x(Layout const& layout) } else { return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); } - + CUTE_GCC_UNREACHABLE; } @@ -1944,185 +1944,4 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& } #endif -// Generic 2D Layout to console table -template -CUTE_HOST_DEVICE -void -print_layout(Layout const& layout) // (m,n) -> idx -{ - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - - int idx_width = num_digits(cosize(layout)) + 2; - const char* delim = "+-----------------------"; - - print(layout); print("\n"); - - // Column indices - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } - printf("\n"); - - // Print out A m-by-n - for (int m = 0; m < size<0>(layout); ++m) { - // Header - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } - printf("+\n"); - // Values - printf("%2d ", m); // Row indices - for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } - printf("|\n"); - } - // Footer - print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } - printf("+\n"); -} - -// Generic ThrVal 2D Layout to console table -template -CUTE_HOST_DEVICE -void -print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx -{ - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - - print(layout); print("\n"); - print(thrid); print("\n"); - - // Print out m-by-n - for (int m = 0; m < size<0>(layout); ++m) { - // Header - for (int n = 0; n < size<1>(layout); ++n) printf("+------"); - printf("+\n"); - // Values - for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid))); - printf("|\n"); - } - // Footer - for (int n = 0; n < size<1>(layout); ++n) printf("+------"); - printf("+\n"); -} - -struct TikzColor_White { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - return "white"; - } -}; - -struct TikzColor_BWx8 { - CUTE_HOST_DEVICE char const* - operator()(int idx) const { - static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", - "black!10", "black!50", "black!30", "black!70"}; - return color_map[idx % 8]; - } -}; - -struct TikzColor_TV { - CUTE_HOST_DEVICE char const* - operator()(int tid, int vid) const { - static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - return color_map[tid % 8]; - } -}; - -// Generic 2D Layout to LaTeX printer -template -CUTE_HOST_DEVICE -void -print_latex(LayoutA const& layout_a, // (m,n) -> idx - TikzColorFn color = {}) // lambda(idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); - auto layout = append<2>(layout_a, Layout<_1,_0>{}); - - // Commented print(layout) - printf("%% Layout: "); print(layout); printf("\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // Layout - for (int i = 0; i < size<0>(layout); ++i) { - for (int j = 0; j < size<1>(layout); ++j) { - int idx = layout(i,j); - printf("\\node[fill=%s] at (%d,%d) {%d};\n", - color(idx), i, j, idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", - int(size<0>(layout)), int(size<1>(layout))); - // Labels - for (int i = 0, j = -1; i < size<0>(layout); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int i = -1, j = 0; j < size<1>(layout); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - -// Generic ThrVal 2D Layout to LaTeX TikZ -template -CUTE_HOST_DEVICE -void -print_latex(Layout const& layout, // (m,n) -> (tid,vid) - ThrID const& thr, // tid -> thr_idx - TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string -{ - CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - - // Commented prints - printf("%% Layout: "); print(layout); printf("\n"); - printf("%% ThrID : "); print(thr); printf("\n"); - // Header - printf("\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); - - // Layout - for (int i = 0; i < size<0>(layout); ++i) { - for (int j = 0; j < size<1>(layout); ++j) { - int thrid = layout(i,j) % size(thr); - int val_idx = layout(i,j) / size(thr); - int thr_idx = thr(thrid); - - printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color(thr_idx, val_idx), - i, j, - thr_idx, val_idx); - } - } - // Grid - printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", - int(size<0>(layout)), int(size<1>(layout))); - // Labels - for (int i = 0, j = -1; i < size<0>(layout); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int j = 0, i = -1; j < size<1>(layout); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } - - // Footer - printf("\\end{tikzpicture}\n" - "\\end{document}\n"); -} - } // end namespace cute diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 60a4ff4a..9c811340 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -42,22 +42,15 @@ namespace cute { template -struct ArithmeticTuple : tuple -{ - template +struct ArithmeticTuple : public tuple { CUTE_HOST_DEVICE constexpr - ArithmeticTuple(ArithmeticTuple const& u) - : tuple(static_cast const&>(u)) {} + ArithmeticTuple() : tuple() {} - template CUTE_HOST_DEVICE constexpr - ArithmeticTuple(tuple const& u) - : tuple(u) {} + ArithmeticTuple(tuple const& t) : tuple(t) {} - template CUTE_HOST_DEVICE constexpr - ArithmeticTuple(U const&... u) - : tuple(u...) {} + ArithmeticTuple(T const&... t) : tuple(t...) {} }; template @@ -147,12 +140,12 @@ operator-(ArithmeticTuple const& t) { } // -// Special cases +// Special cases for C<0> // template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator+(C, ArithmeticTuple const& u) { static_assert(t == 0, "Arithmetic tuple op+ error!"); return u; @@ -160,7 +153,7 @@ operator+(C, ArithmeticTuple const& u) { template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator+(ArithmeticTuple const& t, C) { static_assert(u == 0, "Arithmetic tuple op+ error!"); return t; @@ -168,7 +161,7 @@ operator+(ArithmeticTuple const& t, C) { template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator-(C, ArithmeticTuple const& u) { static_assert(t == 0, "Arithmetic tuple op- error!"); return -u; @@ -176,7 +169,7 @@ operator-(C, ArithmeticTuple const& u) { template CUTE_HOST_DEVICE constexpr -ArithmeticTuple const& +ArithmeticTuple operator-(ArithmeticTuple const& t, C) { static_assert(u == 0, "Arithmetic tuple op- error!"); return t; @@ -212,27 +205,20 @@ struct ArithmeticTupleIterator } }; -template +template CUTE_HOST_DEVICE constexpr auto -make_inttuple_iter(Tuple const& t) { - return ArithmeticTupleIterator(as_arithmetic_tuple(t)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) { - return make_inttuple_iter(cute::make_tuple(t0, t1, ts...)); +make_inttuple_iter(Ts const&... ts) { + return ArithmeticTupleIterator(as_arithmetic_tuple(ts...)); } // // ArithmeticTuple "basis" elements -// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: +// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: // (_0,_0,...,T,_0,...) // with value T in the Nth mode -template +template struct ScaledBasis : private tuple { CUTE_HOST_DEVICE constexpr @@ -243,40 +229,61 @@ struct ScaledBasis : private tuple CUTE_HOST_DEVICE constexpr decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } + // Deprecated: Get the first hierarchical mode in this basis. CUTE_HOST_DEVICE static constexpr - auto mode() { return Int{}; } + auto mode() { return get<0>(int_sequence{}); } }; +// Ensure flat representation +template +struct ScaledBasis, Ns...> : ScaledBasis {}; + template struct is_scaled_basis : false_type {}; -template -struct is_scaled_basis> : true_type {}; +template +struct is_scaled_basis> : true_type {}; -template -struct is_integral> : true_type {}; +template +struct is_integral> : true_type {}; -// Get the scalar T out of a ScaledBasis -template -CUTE_HOST_DEVICE constexpr auto -basis_value(SB const& e) +// Shortcuts +// E<> := _1 +// E<0> := (_1,_0,_0,...) +// E<1> := (_0,_1,_0,...) +// E<0,0> := ((_1,_0,_0,...),_0,_0,...) +// E<0,1> := ((_0,_1,_0,...),_0,_0,...) +// E<1,0> := (_0,(_1,_0,_0,...),_0,...) +// E<1,1> := (_0,(_0,_1,_0,...),_0,...) +template +using E = ScaledBasis,Ns...>; + +// Apply the Ns... pack to another Tuple +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(T const&, Tuple&& t) { - if constexpr (is_scaled_basis::value) { - return basis_value(e.value()); + return static_cast(t); +} + +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(ScaledBasis const&, Tuple&& t) +{ + if constexpr (sizeof...(Ns) == 0) { + return static_cast(t); } else { - return e; + return get(static_cast(t)); } CUTE_GCC_UNREACHABLE; } -// Apply the N... pack to another Tuple -template +template CUTE_HOST_DEVICE decltype(auto) -basis_get(SB const& e, Tuple&& t) -{ - if constexpr (is_scaled_basis::value) { - return basis_get(e.value(), get(static_cast(t))); +basis_value(T const& e) { + if constexpr (is_scaled_basis::value) { + return e.value(); } else { - return static_cast(t); + return e; } CUTE_GCC_UNREACHABLE; } @@ -294,65 +301,34 @@ to_atuple_i(T const& t, seq) { // Turn a ScaledBases into a rank-N+1 ArithmeticTuple // with N prefix 0s: (_0,_0,...N...,_0,T) -template +template CUTE_HOST_DEVICE constexpr auto -as_arithmetic_tuple(ScaledBasis const& t) { - return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq{}); +as_arithmetic_tuple(ScaledBasis const& t) { + return t.value(); } -namespace detail { +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return detail::to_atuple_i(as_arithmetic_tuple(ScaledBasis{t.value()}), make_seq{}); +} -template -struct Basis; - -template <> -struct Basis<> { - using type = Int<1>; -}; - -template -struct Basis { - using type = ScaledBasis::type, N>; -}; - -} // end namespace detail - -// Shortcut for writing ScaledBasis, N0>, N1>, ...> -// E<> := _1 -// E<0> := (_1,_0,_0,...) -// E<1> := (_0,_1,_0,...) -// E<0,0> := ((_1,_0,_0,...),_0,_0,...) -// E<0,1> := ((_0,_1,_0,...),_0,_0,...) -// E<1,0> := (_0,(_1,_0,_0,...),_0,...) -// E<1,1> := (_0,(_0,_1,_0,...),_0,...) -template -using E = typename detail::Basis::type; - -template +template CUTE_HOST_DEVICE constexpr auto make_basis_like(Shape const& shape) { - if constexpr (is_integral::value) { - return Int<1>{}; - } else { - // Generate bases for each rank of shape + if constexpr (is_tuple::value) { + // Generate bases for each mode of shape return transform(tuple_seq{}, shape, [](auto I, auto si) { - // Generate bases for each rank of si and add an i on front - using I_type = decltype(I); - return transform_leaf(make_basis_like(si), [](auto e) { - // MSVC has trouble capturing variables as constexpr, - // so that they can be used as template arguments. - // This is exactly what the code needs to do with i, unfortunately. - // The work-around is to define i inside the inner lambda, - // by using just the type from the enclosing scope. - constexpr int i = I_type::value; - return ScaledBasis{}; - }); + // Generate bases for each si and add an i on end + return make_basis_like(si); }); + } else { + return E{}; } - CUTE_GCC_UNREACHABLE; } @@ -360,109 +336,124 @@ make_basis_like(Shape const& shape) // Arithmetic // -template +template CUTE_HOST_DEVICE constexpr auto -safe_div(ScaledBasis const& b, U const& u) +safe_div(ScaledBasis const& b, U const& u) { auto t = safe_div(b.value(), u); - return ScaledBasis{t}; + return ScaledBasis{t}; } -template +template CUTE_HOST_DEVICE constexpr auto -ceil_div(ScaledBasis const& b, U const& u) +ceil_div(ScaledBasis const& b, U const& u) { auto t = ceil_div(b.value(), u); - return ScaledBasis{t}; + return ScaledBasis{t}; } -template +template CUTE_HOST_DEVICE constexpr auto -abs(ScaledBasis const& e) +abs(ScaledBasis const& e) { auto t = abs(e.value()); - return ScaledBasis{t}; + return ScaledBasis{t}; } // Equality -template +template CUTE_HOST_DEVICE constexpr auto -operator==(ScaledBasis const& t, ScaledBasis const& u) { - return bool_constant{} && t.value() == u.value(); +operator==(ScaledBasis const& t, ScaledBasis const& u) { + if constexpr (sizeof...(Ns) == sizeof...(Ms)) { + return bool_constant<((Ns == Ms) && ...)>{} && t.value() == u.value(); + } else { + return false_type{}; + } + CUTE_GCC_UNREACHABLE; } // Not equal to anything else -template +template CUTE_HOST_DEVICE constexpr false_type -operator==(ScaledBasis const&, U const&) { +operator==(ScaledBasis const&, U const&) { return {}; } -template +template CUTE_HOST_DEVICE constexpr false_type -operator==(T const&, ScaledBasis const&) { +operator==(T const&, ScaledBasis const&) { return {}; } // Multiplication -template +template CUTE_HOST_DEVICE constexpr auto -operator*(A const& a, ScaledBasis const& e) { +operator*(A const& a, ScaledBasis const& e) { auto r = a * e.value(); - return ScaledBasis{r}; + return ScaledBasis{r}; } -template +template CUTE_HOST_DEVICE constexpr auto -operator*(ScaledBasis const& e, B const& b) { +operator*(ScaledBasis const& e, B const& b) { auto r = e.value() * b; - return ScaledBasis{r}; + return ScaledBasis{r}; } // Addition -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, ScaledBasis const& u) { +operator+(ScaledBasis const& t, ScaledBasis const& u) { return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, ArithmeticTuple const& u) { +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { return as_arithmetic_tuple(t) + u; } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ArithmeticTuple const& t, ScaledBasis const& u) { +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { return t + as_arithmetic_tuple(u); } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(C, ScaledBasis const& u) { - static_assert(t == 0, "ScaledBasis op+ error!"); - return u; +operator+(C, ScaledBasis const& u) { + if constexpr (sizeof...(Ms) == 0) { + return C{} + u.value(); + } else { + static_assert(t == 0, "ScaledBasis op+ error!"); + return u; + } + CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, C) { - static_assert(u == 0, "ScaledBasis op+ error!"); - return t; +operator+(ScaledBasis const& t, C) { + if constexpr (sizeof...(Ns) == 0) { + return t.value() + C{}; + } else { + static_assert(u == 0, "ScaledBasis op+ error!"); + return t; + } + CUTE_GCC_UNREACHABLE; } // @@ -475,10 +466,10 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) printf("ArithTuple"); print(iter.coord_); } -template -CUTE_HOST_DEVICE void print(ScaledBasis const& e) +template +CUTE_HOST_DEVICE void print(ScaledBasis const& e) { - print(e.value()); printf("@%d", N); + print(e.value()); (void(printf("@%d", Ns)), ...); } #if !defined(__CUDACC_RTC__) @@ -488,10 +479,11 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator -CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { - return os << e.value() << "@" << N; + os << e.value(); (void(os << "@" << Ns), ...); + return os; } #endif diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index 485c07d5..71464e72 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -47,8 +47,9 @@ namespace cute // Signed integers // -using int2_t = cutlass::int2b_t; -using int4_t = cutlass::int4b_t; +using int2_t = cutlass::int2b_t; +using int4_t = cutlass::int4b_t; +using int6_t = cutlass::int6b_t; using CUTE_STL_NAMESPACE::int8_t; using CUTE_STL_NAMESPACE::int16_t; using CUTE_STL_NAMESPACE::int32_t; @@ -75,10 +76,10 @@ using int_byte_t = typename int_byte::type; // Unsigned integers // -using uint1_t = cutlass::uint1b_t; -using uint2_t = cutlass::uint2b_t; -using uint4_t = cutlass::uint4b_t; -using uint6_t = cutlass::uint6b_t; +using uint1_t = cutlass::uint1b_t; +using uint2_t = cutlass::uint2b_t; +using uint4_t = cutlass::uint4b_t; +using uint6_t = cutlass::uint6b_t; using CUTE_STL_NAMESPACE::uint8_t; using CUTE_STL_NAMESPACE::uint16_t; using CUTE_STL_NAMESPACE::uint32_t; @@ -88,7 +89,7 @@ template struct uint_bit; template <> struct uint_bit< 1> { using type = uint1_t; }; template <> struct uint_bit< 2> { using type = uint2_t; }; template <> struct uint_bit< 4> { using type = uint4_t; }; -template <> struct uint_bit< 6> { using type = uint6_t; }; +template <> struct uint_bit< 6> { using type = uint6_t; }; template <> struct uint_bit< 8> { using type = uint8_t; }; template <> struct uint_bit< 16> { using type = uint16_t; }; template <> struct uint_bit< 32> { using type = uint32_t; }; diff --git a/include/cute/numeric/numeric_types.hpp b/include/cute/numeric/numeric_types.hpp index 892ec706..1f03a6d0 100644 --- a/include/cute/numeric/numeric_types.hpp +++ b/include/cute/numeric/numeric_types.hpp @@ -38,10 +38,19 @@ namespace cute { -template -struct sizeof_bits : public cutlass::sizeof_bits {}; +template +struct sizeof_bits : cutlass::sizeof_bits {}; -// DO NOT change auto to int, sizeof_bits use integral_ratio instead of int +template +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; + +// DO NOT change auto to int, sizeof_bits use integral_ratio instead of int template static constexpr auto sizeof_bits_v = sizeof_bits::value; @@ -53,6 +62,23 @@ using cutlass::is_subbyte; template static constexpr auto is_subbyte_v = is_subbyte::value; +// +// Integral +// + +using cutlass::bin1_t; +using cutlass::uint1b_t; +using cutlass::int2b_t; +using cutlass::uint2b_t; +using cutlass::int4b_t; +using cutlass::uint4b_t; +using cutlass::int6b_t; +using cutlass::uint6b_t; + +// +// Floating Point +// + using cutlass::half_t; using cutlass::bfloat16_t; @@ -65,18 +91,12 @@ using cutlass::type_erased_dynamic_float8_t; using cutlass::float_e4m3_t; using cutlass::float_e5m2_t; -using cutlass::uint1b_t; -using cutlass::int2b_t; -using cutlass::uint2b_t; -using cutlass::int4b_t; -using cutlass::uint4b_t; -using cutlass::bin1_t; + using cutlass::float_ue4m3_t; using cutlass::float_ue8m0_t; -using cutlass::uint6b_t; using cutlass::float_e2m1_t; using cutlass::float_e2m3_t; using cutlass::float_e3m2_t; @@ -94,8 +114,6 @@ using cutlass::detail::type_erased_dynamic_float4_unpacksmem_t; using cutlass::detail::type_erased_dynamic_float6_unpacksmem_t; }; - - // // Print utility // @@ -112,7 +130,6 @@ print(bfloat16_t a) { printf("%f", static_cast(a)); } - CUTE_HOST_DEVICE void print(tfloat32_t a) { @@ -131,6 +148,15 @@ print(float_e5m2_t a) { printf("%f", static_cast(a)); } +template +CUTE_HOST_DEVICE +void +print(cutlass::float_exmy_base a) { + printf("%f", static_cast(a)); +} + +// Pretty Print utility + CUTE_HOST_DEVICE void pretty_print(bfloat16_t v) { printf("%*.2f", 8, float(v)); @@ -156,26 +182,11 @@ pretty_print(float_e5m2_t t) { printf("%*.2f", 8, static_cast(t)); } - -template < - cutlass::detail::FpEncoding Encoding, - class Derived -> -CUTE_HOST_DEVICE -void -print(cutlass::float_exmy_base a) { - printf("%f", static_cast(a)); -} - -template < - cutlass::detail::FpEncoding Encoding, - class Derived -> +template CUTE_HOST_DEVICE void pretty_print_float_exmy_base(cutlass::float_exmy_base t) { printf("%*.2f", 8, static_cast(t)); } - } // namespace cute diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 3c42fd29..24146e25 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -33,9 +33,9 @@ #include // CUTE_HOST_DEVICE #include // cute::iter_adaptor #include -#include // cute::subbyte_iterator #include // cute::true_type, cute::false_type #include // sizeof_bits +#include // cute::subbyte_iterator namespace cute { @@ -51,11 +51,13 @@ namespace cute // Requires construction of a sparse_ptr that emulates access to the S logical elements. // -template +template CUTE_HOST_DEVICE constexpr auto -recast_ptr(void* ptr) +recast_ptr(T* ptr) { + using NewT = copy_cv_t; + if constexpr (is_sparse::value) { constexpr int sparsity = NewT::sparsity; NewT* p = reinterpret_cast(ptr); @@ -69,24 +71,6 @@ recast_ptr(void* ptr) CUTE_GCC_UNREACHABLE; } -template -CUTE_HOST_DEVICE constexpr -auto -recast_ptr(void const* ptr) -{ - if constexpr (is_sparse::value) { - constexpr int sparsity = NewT::sparsity; - NewT const* p = reinterpret_cast(ptr); - return make_sparse_ptr(p); - } else - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } else { - return reinterpret_cast(ptr); - } - CUTE_GCC_UNREACHABLE; -} - // Disambiguate nullptr template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/pointer_flagged.hpp b/include/cute/pointer_flagged.hpp index 7f205347..0e38cc72 100644 --- a/include/cute/pointer_flagged.hpp +++ b/include/cute/pointer_flagged.hpp @@ -167,23 +167,6 @@ downcast(ComposedLayout,Layout> const& // Display utilities // -// Capture and cast smem_ptr_flag Layouts to offset-0 layouts -template -CUTE_HOST_DEVICE -void -print_layout(ComposedLayout,Layout> const& layout) -{ - print_layout(as_position_independent_swizzle_layout(layout)); -} - -template -CUTE_HOST_DEVICE -void -print_latex(ComposedLayout,Layout> const& layout) -{ - print_latex(as_position_independent_swizzle_layout(layout)); -} - template CUTE_HOST_DEVICE void print(smem_ptr_flag_bits ptr) { diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index 1ab62fd5..659c903f 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -56,3 +56,9 @@ #include #include +// +// Utilities +// + +#include +#include diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index e65ad419..007d8c03 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -753,24 +753,30 @@ domain_offset(Coord const& coord, Tensor&& tensor) // -- doesn't check dynamic integer divisibility // -- doesn't check alignment -template +template CUTE_HOST_DEVICE constexpr auto recast(Tensor&& tensor) { - using OldType = typename remove_cvref_t::value_type; + using OldType = typename remove_cvref_t::element_type; + using NewType = copy_cv_t; + auto old_layout = tensor.layout(); auto new_layout = recast_layout(old_layout); - // If this is an upcast of a normal Layout with static negative strides, then offset as well - if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { - auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); - auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); - auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); - - return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); + if constexpr (is_same::value) { + return tensor; } else { - return make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); + // If this is an upcast of a normal Layout with static negative strides, then offset as well + if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { + auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); + auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); + auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); + + return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); + } else { + return make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); + } } CUTE_GCC_UNREACHABLE; @@ -1114,95 +1120,5 @@ CUTE_HOST_DEVICE void print(Tensor const& tensor) print(tensor.data()); print(" o "); print(tensor.layout()); } -template -CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor, bool print_type = true) -{ - if (print_type) { - print(tensor); print(":\n"); - } - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - pretty_print(tensor(m)); - printf("\n"); - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - pretty_print(tensor(m,n)); - } - printf("\n"); - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor(tensor(_,_,0), false); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); - print_tensor(tensor(_,_,k), false); - } - } else - if constexpr (Layout::rank == 4) - { - print_tensor(tensor(_,_,_,0), false); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); - print_tensor(tensor(_,_,_,p), false); - } - } -} - -#if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) -{ - int digits = 9; - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - os << std::setw(digits) << tensor(m) << std::endl; - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - os << std::setw(digits) << tensor(m,n); - } - os << std::endl; - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor_os(os, tensor(_,_,0)); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; - print_tensor_os(os, tensor(_,_,k)); - } - } else - if constexpr (Layout::rank == 4) - { - print_tensor_os(os, tensor(_,_,_,0)); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; - print_tensor_os(os, tensor(_,_,_,p)); - } - } - - return os; -} - -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) -{ - os << tensor.layout() << std::endl; - return print_tensor_os(os, tensor); -} -#endif // !defined(__CUDACC_RTC__) - } // end namespace cute diff --git a/include/cute/util/print_latex.hpp b/include/cute/util/print_latex.hpp new file mode 100644 index 00000000..28e30ed5 --- /dev/null +++ b/include/cute/util/print_latex.hpp @@ -0,0 +1,438 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include +#include + +#include +#include + +namespace cute +{ + +/////////////////////////////////////// +// Common LaTeX TikZ Color utilities // +/////////////////////////////////////// + +struct TikzColor_White { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + return "white"; + } +}; + +struct TikzColor_BWx8 { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", + "black!10", "black!50", "black!30", "black!70"}; + return color_map[idx % 8]; + } +}; + +struct TikzColor_TV { + CUTE_HOST_DEVICE char const* + operator()(int tid, int vid) const { + static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + return color_map[tid % 8]; + } +}; + +///////////////////////////// +// Layout 2D to LaTeX TikZ // +///////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_latex(LayoutA const& layout_a, // (m,n) -> idx + TikzColorFn color = {}) // lambda(idx) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); + auto layout = append<2>(layout_a, Layout<_1,_0>{}); + + // Commented print(layout) + printf("%% Layout: "); print(layout); printf("\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N] = product_each(shape(layout)); + + // Layout + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + int idx = layout(m,n); + printf("\\node[fill=%s] at (%d,%d) {%d};\n", + color(idx), m, n, idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(M), int(N)); + // Labels + for (int m = 0, n = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); + } + for (int m = -1, n = 0; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +template +CUTE_HOST_DEVICE +void +print_latex(ComposedLayout,Layout> const& layout, + TikzColorFn color = {}) // lambda(idx) -> tikz color string) +{ + print_latex(as_position_independent_swizzle_layout(layout), color); +} + +/////////////////////////////// +// LayoutTV 2D to LaTeX TikZ // +/////////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_latex_tv(LayoutTV const& layout_tv, // (t,v) -> m,n coord + Tile_MN const& tile_mn, // (M,N) + TikzColorFn color = {}) // (t,v) -> color +{ + CUTE_STATIC_ASSERT_V(rank(layout_tv) == Int<2>{}); + + // Commented prints + printf("%% Layout TV: "); print(layout_tv); printf("\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N] = product_each(shape(tile_mn)); + Tensor filled = make_tensor(make_shape(M, N)); + clear(filled); + + // Layout + for (int tid = 0; tid < size<0>(layout_tv); ++tid) { + for (int vid = 0; vid < size<1>(layout_tv); ++vid) { + auto [m, n] = layout_tv(tid, vid); + if (not filled(m, n)) { + filled(m, n) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", int(M), int(N)); + // Labels + for (int m = 0, n = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); + } + for (int n = 0, m = -1; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); + } + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +//////////////////////////// +// MMA Atom to LaTeX TikZ // +//////////////////////////// + +namespace detail { + +template +CUTE_HOST_DEVICE +void +print_latex_mma(LayoutC const& C, // (tid,vid) -> (m,n) coord + LayoutA const& A, // (tid,vid) -> (m,k) coord + LayoutB const& B, // (tid,vid) -> (n,k) coord + Tile_MNK const& tile_mnk, // (M,N,K) + TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + // Commented prints + printf("%% LayoutC: "); print(C); printf("\n"); + printf("%% LayoutA: "); print(A); printf("\n"); + printf("%% LayoutB: "); print(B); printf("\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N, K] = product_each(shape(tile_mnk)); + Tensor filled = make_tensor(make_shape(M, N, K)); + clear(filled); + + // C starting at 0,0 + for (int tid = 0; tid < size<0>(C); ++tid) { + for (int vid = 0; vid < size<1>(C); ++vid) { + auto [m, n] = C(tid, vid); + if (not filled(m, n, 0)) { + filled(m, n, 0) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(M), int(N)); + + clear(filled); + + // A starting at 0,-K-1 + for (int tid = 0; tid < size<0>(A); ++tid) { + for (int vid = 0; vid < size<1>(A); ++vid) { + auto [m, k] = A(tid, vid); + if (not filled(m, 0, k)) { + filled(m, 0, k) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(k-K-1), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, -int(K)-1, int(M), -1); + // A labels + for (int m = 0, k = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), m); + } + for (int m = -1, k = 0; k < K; ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), k); + } + + clear(filled); + + // B starting at -K-1,0 + for (int tid = 0; tid < size<0>(B); ++tid) { + for (int vid = 0; vid < size<1>(B); ++vid) { + auto [n, k] = B(tid, vid); + if (not filled(0, n, k)) { + filled(0, n, k) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(k)-int(K)-1, int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + -int(K)-1, 0, -1, int(N)); + // B labels + for (int n = 0, k = -1; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, n); + } + for (int n = -1, k = 0; k < K; ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, k); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +} // end namespace detail + +// MMA Atom to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(MMA_Atom const& mma_atom, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + print_latex(make_tiled_mma(mma_atom)); +} + +// TiledMMA to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(TiledMMA const& mma, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + auto tile_mnk = tile_shape(mma); + + Tensor refC = make_identity_tensor(select<0,1>(tile_mnk)); + Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV()); + + Tensor refA = make_identity_tensor(select<0,2>(tile_mnk)); + Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV()); + + Tensor refB = make_identity_tensor(select<1,2>(tile_mnk)); + Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV()); + + detail::print_latex_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color); +} + +//////////////////////////// +// CopyAtom to LaTeX TikZ // +//////////////////////////// + +namespace detail { + +// Generic TV Layout to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex_copy(LayoutS_TV const& S, // (t,v) -> m,n coord + LayoutD_TV const& D, // (t,v) -> m,n coord + Tile_MN const& tile_mn, // (M,N) + TikzColorFn color = {}) // (t,v) -> color +{ + CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); + + // Commented prints + printf("%% Layout S TV: "); print(S); printf("\n"); + printf("%% Layout D TV: "); print(D); printf("\n"); + + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + auto [M, N] = product_each(shape(tile_mn)); + Tensor filled = make_tensor(make_shape(M, N)); + clear(filled); + + // S starting at 0,0 + for (int tid = 0; tid < size<0>(S); ++tid) { + for (int vid = 0; vid < size<1>(S); ++vid) { + auto [m, n] = S(tid, vid); + if (not filled(m, n)) { + filled(m, n) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n), + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(M), int(N)); + // S Labels + for (int m = 0, n = -1; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, m); + } + for (int m = -1, n = 0; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, n, n); + } + + clear(filled); + + // D starting at 0,N+3 + for (int tid = 0; tid < size<0>(D); ++tid) { + for (int vid = 0; vid < size<1>(D); ++vid) { + auto [m, n] = D(tid, vid); + if (not filled(m, n)) { + filled(m, n) = true; + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(tid, vid), + int(m), int(n) + int(N) + 3, + tid, vid); + } + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, int(N) + 3, int(M), int(N) + int(N) + 3); + // D Labels + for (int m = 0, n = N; m < M; ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), m); + } + for (int m = -1, n = 0; n < N; ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(n+N+3), n); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +} // end namespace detail + +// TiledCopy to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(TiledCopy const& copy, + TikzColorFn color = {}) // lambda(tid,vid) -> tikz color string +{ + auto tiler_mn = typename TiledCopy::Tiler_MN{}; + auto tile_mn = product_each(shape(logical_divide(make_layout(Shape<_1,_1>{}), tiler_mn))); // tile_shape + + Tensor refS = make_identity_tensor(tile_mn); + Tensor layoutS_TV = copy.tidfrg_S(refS)(_,_,Int<0>{}); + + Tensor refD = make_identity_tensor(tile_mn); + Tensor layoutD_TV = copy.tidfrg_D(refD)(_,_,Int<0>{}); + + detail::print_latex_copy(layoutS_TV, layoutD_TV, tile_mn, color); +} + +} // end namespace cute diff --git a/include/cute/util/print_svg.hpp b/include/cute/util/print_svg.hpp new file mode 100644 index 00000000..5d26809e --- /dev/null +++ b/include/cute/util/print_svg.hpp @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include +#include + +#include +#include + +namespace cute +{ + +//////////////////////////////// +// Common SVG Color utilities // +//////////////////////////////// + +struct TSVGColor_White { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + return "255,255,255"; + } +}; + +struct TSVGColor_BWx8 { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + static char const* color_map[8] = {"255,255,255", "230,230,230", "205,205,205", "180,180,180", + "155,155,155", "130,130,130", "105,105,105", "080,080,080"}; + return color_map[idx % 8]; + } +}; + +struct SVGColor_TV { + CUTE_HOST_DEVICE char const* + operator()(int tid, int vid) const { + static char const* color_map[8] = {"175,175,255", "175,255,175", "255,255,175", "255,175,175", + "210,210,255", "210,255,210", "255,255,210", "255,210,210"}; + return color_map[tid % 8]; + } +}; + +///////////////////// +// MMA Atom to SVG // +///////////////////// + +namespace detail { + +template +CUTE_HOST_DEVICE +void +print_svg_mma(LayoutC const& C, + LayoutA const& A, + LayoutB const& B, + Tile_MNK const& tile_mnk, + SVGColorFn color = {}) // lambda(tid,vid) -> SVG color string +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + auto [M, N, K] = product_each(shape(tile_mnk)); + + int cell_size = 20; + + int page_width = (K + N + 2) * cell_size; + int page_height = (K + M + 2) * cell_size; + + // Commented print + printf("\n"); + printf("\n"); + printf("\n"); + printf("\n"); + + // SVG Header + printf("\n", + page_width, page_height); + + Tensor filled = make_tensor(make_shape(M, N, K)); + clear(filled); + + // --- Draw C --- + for (int tid = 0; tid < size<0>(C); ++tid) { + for (int vid = 0; vid < size<1>(C); ++vid) { + auto [m, n] = C(tid, vid); + if (!filled(m, n, 0)) { + filled(m, n, 0) = true; + + int x = (n + K + 2) * cell_size; + int y = (m + K + 2) * cell_size; + + printf("\n", + x, y, cell_size, cell_size, color(tid,vid)); + printf("T%d\n", + x + cell_size/2, y + 1*cell_size/4, tid); + printf("V%d\n", + x + cell_size/2, y + 3*cell_size/4, vid); + } + } + } + + clear(filled); + + // --- Draw A --- + for (int tid = 0; tid < size<0>(A); ++tid) { + for (int vid = 0; vid < size<1>(A); ++vid) { + auto [m, k] = A(tid, vid); + if (!filled(m, 0, k)) { + filled(m, 0, k) = true; + + int x = (k + 1) * cell_size; + int y = (m + K + 2) * cell_size; + + printf("\n", + x, y, cell_size, cell_size, color(tid,vid)); + printf("T%d\n", + x + cell_size/2, y + 1*cell_size/4, tid); + printf("V%d\n", + x + cell_size/2, y + 3*cell_size/4, vid); + } + } + } + + // A labels + for (int m = 0, k = -1; m < M; ++m) { + int x = (k + 1) * cell_size; + int y = (m + K + 2) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, m); + } + for (int m = -1, k = 0; k < K; ++k) { + int x = (k + 1) * cell_size; + int y = (m + K + 2) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, k); + } + + clear(filled); + + // --- Draw B --- + for (int tid = 0; tid < size<0>(B); ++tid) { + for (int vid = 0; vid < size<1>(B); ++vid) { + auto [n, k] = B(tid, vid); + if (!filled(0, n, k)) { + filled(0, n, k) = true; + + int x = (n + K + 2) * cell_size; + int y = (k + 1) * cell_size; + + printf("\n", + x, y, cell_size, cell_size, color(tid,vid)); + printf("T%d\n", + x + cell_size/2, y + 1*cell_size/4, tid); + printf("V%d\n", + x + cell_size/2, y + 3*cell_size/4, vid); + } + } + } + + // B labels + for (int n = 0, k = -1; n < N; ++n) { + int x = (n + K + 2) * cell_size; + int y = (k + 1) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, n); + } + for (int n = -1, k = 0; k < K; ++k) { + int x = (n + K + 2) * cell_size; + int y = (k + 1) * cell_size; + printf("%d\n", + x + cell_size/2, y + cell_size/2, k); + } + + // SVG footer + printf("\n"); +} + +} // end namespace detail + +// MMA Atom to SVG +template +CUTE_HOST_DEVICE +void +print_svg(MMA_Atom const& mma_atom, + SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string +{ + print_svg(make_tiled_mma(mma_atom)); +} + +// TiledMMA to SVG +template +CUTE_HOST_DEVICE +void +print_svg(TiledMMA const& mma, + SVGColorFn color = {}) // lambda(thr_idx,val_idx) -> svg color string +{ + auto tile_mnk = tile_shape(mma); + + Tensor refC = make_identity_tensor(select<0,1>(tile_mnk)); + Tensor tensorC_TV = composition(refC, mma.get_layoutC_TV()); + + Tensor refA = make_identity_tensor(select<0,2>(tile_mnk)); + Tensor tensorA_TV = composition(refA, mma.get_layoutA_TV()); + + Tensor refB = make_identity_tensor(select<1,2>(tile_mnk)); + Tensor tensorB_TV = composition(refB, mma.get_layoutB_TV()); + + detail::print_svg_mma(tensorC_TV, tensorA_TV, tensorB_TV, tile_mnk, color); +} + +} // end namespace cute diff --git a/include/cute/util/print_tensor.hpp b/include/cute/util/print_tensor.hpp new file mode 100644 index 00000000..c5eb39a1 --- /dev/null +++ b/include/cute/util/print_tensor.hpp @@ -0,0 +1,188 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include +#include + +namespace cute +{ + +//////////////////////////////// +// Layout 2D to Console table // +//////////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout) // (m,n) -> idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + int idx_width = num_digits(cosize(layout)) + 2; + const char* delim = "+-----------------------"; + + print(layout); print("\n"); + + // Column indices + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } + printf("\n"); + + // Print out A m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); + // Values + printf("%2d ", m); // Row indices + for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } + printf("|\n"); + } + // Footer + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); +} + +// Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_layout(ComposedLayout,Layout> const& layout) +{ + print_layout(as_position_independent_swizzle_layout(layout)); +} + +//////////////////////////////// +// Tensor 1D,2D,3D,4D Console // +//////////////////////////////// + +template +CUTE_HOST_DEVICE +void +print_tensor(Tensor const& tensor, bool print_type = true) +{ + if (print_type) { + print(tensor); print(":\n"); + } + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + pretty_print(tensor(m)); + printf("\n"); + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + pretty_print(tensor(m,n)); + } + printf("\n"); + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor(tensor(_,_,0), false); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); + print_tensor(tensor(_,_,k), false); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor(tensor(_,_,_,0), false); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); + print_tensor(tensor(_,_,_,p), false); + } + } +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST +std::ostream& +print_tensor_os(std::ostream& os, Tensor const& tensor) +{ + int digits = 9; + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + os << std::setw(digits) << tensor(m) << std::endl; + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + os << std::setw(digits) << tensor(m,n); + } + os << std::endl; + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor_os(os, tensor(_,_,0)); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; + print_tensor_os(os, tensor(_,_,k)); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor_os(os, tensor(_,_,_,0)); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; + print_tensor_os(os, tensor(_,_,_,p)); + } + } + + return os; +} + +template +CUTE_HOST +std::ostream& +operator<<(std::ostream& os, Tensor const& tensor) +{ + os << tensor.layout() << std::endl; + return print_tensor_os(os, tensor); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index ee361c7d..34cc5ca9 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -92,6 +92,29 @@ using CUTE_STL_NAMESPACE::remove_const_t; using CUTE_STL_NAMESPACE::remove_cv_t; using CUTE_STL_NAMESPACE::remove_reference_t; +template +struct copy_cv { + using type = Dst; +}; + +template +struct copy_cv { + using type = Dst const; +}; + +template +struct copy_cv { + using type = Dst volatile; +}; + +template +struct copy_cv { + using type = Dst const volatile; +}; + +template +using copy_cv_t = typename copy_cv::type; + using CUTE_STL_NAMESPACE::extent; using CUTE_STL_NAMESPACE::remove_extent; diff --git a/include/cutlass/arch/mma_sm100.h b/include/cutlass/arch/mma_sm100.h new file mode 100644 index 00000000..46fb31f6 --- /dev/null +++ b/include/cutlass/arch/mma_sm100.h @@ -0,0 +1,118 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#include + +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/config.h" +#include "cute/arch/simd_sm100.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass{ +namespace arch { + + +/// Matrix multiply-add operation +template < + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = ElementC_; + + CUTLASS_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[0] + c[i]; + } + } +}; + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = float; + + CUTLASS_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + float2 result; + cute::fma(result, make_float2(a[0], a[1]), make_float2(b[0], b[0]), make_float2(c[0], c[1])); + d[0] = result.x; + d[1] = result.y; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index aed30bee..3847c712 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -88,7 +88,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -136,7 +136,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -184,7 +184,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -250,7 +250,7 @@ struct LayoutAwareConvertImpl< static void convert( cute::Tensor, cute::Stride<_4,_1>> - > const& src, + > const& src, cute::Tensor >& dst) { @@ -477,7 +477,7 @@ void LayoutAwareConvert( Tensor dst_vm = coalesce(dst); Layout src_layout = src_vm.layout(); Layout dst_layout = dst_vm.layout(); - LayoutAwareConvertImpl::convert(src_vm, dst_vm); @@ -487,18 +487,25 @@ void LayoutAwareConvert( ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { + namespace detail { + enum class ConversionMode { + DirectConvert, // A * B + ConvertAndScale, // (scale * A) * B + ConvertAndScaleWithZero // (scale * A + zeros) * B + }; + } // namespace detail +} //namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass::gemm::collective::detail { template static constexpr CUTLASS_HOST_DEVICE auto get_logical_ptr(PointerType const* ptr) { - if constexpr (cute::sizeof_bits_v < 8) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } template static constexpr @@ -530,8 +537,8 @@ auto get_gmem_layout(Shape const& shape, Stride const& stride) { template struct MixedInputUtils { private: + using ConversionMode = cutlass::detail::ConversionMode; using KernelSchedule = typename Collective::KernelSchedule; - using ConversionMode = typename Collective::ConversionMode; using SmemLayoutA = typename Collective::SmemLayoutA; using SmemLayoutB = typename Collective::SmemLayoutB; using SmemLayoutScale = typename Collective::SmemLayoutScale; @@ -551,10 +558,10 @@ public: elements_per_smem_scale() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return 0; - } + } else if constexpr (ModeHasScales) { return cute::cosize_v; - } + } else { static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); } @@ -565,10 +572,10 @@ public: if constexpr (KernelConversionMode == ConversionMode::DirectConvert || KernelConversionMode == ConversionMode::ConvertAndScale ) { return 0; - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { return cute::cosize_v; - } + } else { static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); } @@ -634,7 +641,7 @@ public: // We are starting a new k-tile so copy the scale if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // nothing to do - } + } else if constexpr (ModeHasScales) { auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); @@ -649,13 +656,23 @@ public: } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } } } + // Helper functions to select packing for conversion + template + struct select_packing { // Naive packing policy + static constexpr auto value() { + return Int, sizeof_bits_v))>{}; + } + }; + // The core converter uses a lookup table to converts i4 -> 8 bit value. template && dst, Tensor const& scales_neg, Tensor const& scales_pos) { - + lookup_table_convert(src, dst, scales_neg, scales_pos); } template ; using DstArray = cutlass::Array; @@ -699,7 +716,7 @@ public: // Determines if to get from the signed or unsigned candidates static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 + uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 asm volatile( "{\n" " lop3.b32 %0, %1, %2, %3, %4;\n" \ @@ -743,13 +760,13 @@ public: static_check_scale(flatten(Layout{})); } template CUTLASS_DEVICE static void dequantize_A_kblock( - Tensor const& tCrA_load, + Tensor const& tCrA_load, Tensor& tCrA_mma, cute::tuple& partitioned_extra_info, int const k_block) { @@ -764,7 +781,7 @@ public: Tensor src = tCrA_load(_, _, k_block); Tensor dst = tCrA_mma(_, _, k_block); - + CUTE_STATIC_ASSERT_V(size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory"); // try to make the size of the first mode equal to 32bit @@ -778,7 +795,7 @@ public: for (int i = 0; i < size<1>(dst_vm); ++i) { LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); } - } + } else if constexpr (UseScaleLookupTable) { constexpr int num_elements = decltype(size(src))::value; static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); @@ -856,7 +873,7 @@ public: CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, Int{})); - + if constexpr (is_same_v) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(dst_vm); ++i) { @@ -885,6 +902,7 @@ public: } } + /// Utilities for any additional inputs inside of the TMA load template < class Params, @@ -897,39 +915,39 @@ public: cute::tuple const& load_inputs, TensorStorage& shared_tensors, uint2 const& cluster_local_block_id, - int const m_coord, + int const m_coord, int const l_coord) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return cute::make_tuple(); - } + } else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gS_mkl = get<2>(load_inputs); auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSgS = block_tma_s.partition_S(gS); Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tSgS, tSsS); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gZ_mkl = get<3>(load_inputs); auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZgZ = block_tma_z.partition_S(gZ); Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) - return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); } else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); } } else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); } } @@ -938,7 +956,7 @@ public: class ThreadMma, class TensorStorage > - CUTLASS_DEVICE + CUTLASS_DEVICE static auto partition_extra_mma_info( ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) { @@ -950,8 +968,8 @@ public: else if constexpr (UseScaleLookupTable) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); - Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos); @@ -960,7 +978,7 @@ public: else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); @@ -968,13 +986,13 @@ public: else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsZ = mma_thread_slice.partition_A(sZ); - Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } @@ -996,18 +1014,18 @@ public: auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) - + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index b7102165..b8b752fe 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -1519,6 +1519,105 @@ public: >::CollectiveOp; }; +template < + class MmaTileShape_MNK, + class ClusterShape_MNK, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOp +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassSimt, + MmaTileShape_MNK, + ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v || + cute::is_same_v >> { + using CtaTileShape_MNK = MmaTileShape_MNK; // cluster MMA not supported + + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? + thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using ThreadOp = cute::conditional_t< + IsDefaultFusionOp::value, + thread::LinearCombination< + ElementD, AlignmentD, ElementAccumulator, ElementCompute, + ScaleType, FloatRoundStyle::round_to_nearest, ElementC> + , + thread::LinearCombinationBiasElementwise< + ElementC, ElementAccumulator, ElementCompute, ElementD, ElementD, AlignmentD, + typename FusionOp::ActivationFn, cutlass::plus, + false, typename FusionOp::ElementBias> + >; + static_assert(not (cute::is_same_v && not IsDefaultFusionOp::value), "unsupported schedule + fusion"); + + using WarpShape_MNK = decltype(cutlass::gemm::collective::detail::sm100_simt_f32_warp_shape_mnk_selector()); + static constexpr int ThreadCount = cute::size(WarpShape_MNK{}) * NumThreadsPerWarp; + static constexpr int WarpShape_M = cute::size<0>(WarpShape_MNK{}); + static constexpr int WarpShape_N = cute::size<1>(WarpShape_MNK{}); + + // For 32 threads in 1 warp, we use [8 x 4] thread layouts and each thread will hold [4 x 4] accumulator value layouts. + // Then totally each warp will hold [32 x 16] accumulator value layouts. + // We separate the whole epilogue calculation to multi steps, + // each step will calculate 1x [32 x 16] for each warp to reduce register pressure (mainly for C register allocation for beta 1!= 0 case). + // So EpiTileM = WarpShape_M * 32 and EpiTileN = WarpShape_N * 16. + using EpiTileM = Int; + using EpiTileN = Int; + + using SmemLayout = cute::conditional_t(GmemStrideTypeD{}), + cute::Layout, cute::Stride<_1, EpiTileM>>, + cute::Layout, cute::Stride>>; + + using CopyAtomR2S = Copy_Atom, ElementAccumulator>; + + using CopyAtomS2R = Copy_Atom>, ElementAccumulator>; + + using TiledCopyS2R = decltype( + cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< + CopyAtomS2R, ThreadCount, AlignmentD, GmemStrideTypeD, EpiTileM, EpiTileN>()); + + using Schedule = cute::conditional_t, + EpilogueSimtVectorized, + EpilogueScheduleType>; + using CopyAtomR2G = Copy_Atom>, ElementD>; + using CollectiveOp = cutlass::epilogue::collective::Epilogue< + GmemStrideTypeC, + GmemStrideTypeD, + ThreadOp, + SmemLayout, + CopyAtomR2S, + TiledCopyS2R, + CopyAtomR2G, + Schedule>; +}; /////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::collective diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 94e43baf..fb09f8b1 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -205,6 +205,16 @@ struct IsThreadEpilogueOpWithPerChannelScaling +struct IsThreadEpilogueOpWithResidualAdd { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithResidualAdd > { + static constexpr bool value = ThreadEpilogueOp::IsResidualSupported; +}; + template struct IsThreadEpilogueOpWithActivation { static constexpr bool value = false; diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp index f5e8fb50..d8d99849 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -39,6 +39,8 @@ #include "cutlass/cutlass.h" #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/detail/helper_macros.hpp" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/conv/detail.hpp" #include "cute/tensor.hpp" #include "cute/numeric/numeric_types.hpp" @@ -133,6 +135,7 @@ public: constexpr static int ThreadCount = 128; constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + constexpr static bool isSourceNeeded = not cute::is_void_v; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; constexpr static uint32_t TmaTransactionBytes = 0; @@ -173,12 +176,27 @@ public: return cutlass::Status::kSuccess; } + template + static bool + can_implement(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { + return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); + } + template static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { - return true; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(shape, StrideD{}); + if constexpr (isSourceNeeded) { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + return implementable; } // diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 8cac28f7..114737a9 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -57,6 +57,7 @@ struct FusionOperation { using ElementSource = void; static constexpr bool IsSourceSupported = false; + static constexpr bool IsResidualSupported = false; // Source is added after activation using ElementScalar = void; static constexpr int AlignmentScalar = 0; @@ -317,6 +318,24 @@ struct PerColLinCombPerColBiasEltAct static constexpr bool IsPerColScaleSupported = true; }; +// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerColResAddPerColBiasEltAct + : PerColLinCombPerColBiasEltAct { + static constexpr bool IsResidualSupported = true; +}; + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias // if D is fp8 // D = scale_d * activation(Z) diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 87258c69..95e82086 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -1306,6 +1306,114 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C +template< + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColResAddPerColBiasEltAct = + Sm90EVT, // beta * C + activation(alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast + Sm90SrcFetch, // C + Sm90EVT, // activation(alpha * acc + bias) + Sm90EVT, // alpha * acc + bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + > + >; + + template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerColResAddPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerColResAddPerColBiasEltAct< + CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerColResAddPerColBiasEltAct< + CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerColResAddPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,bool,int64_t>; + using StrideBeta = Stride<_0,bool,int64_t>; + StrideAlpha dAlpha = {_0{}, bool(1), 0}; + StrideBeta dBeta = {_0{}, bool(1), 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + activation(alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // unary op : activation(alpha * acc + bias) + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace detail { template diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 6aec0e83..ae63a767 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -591,7 +591,7 @@ struct Sm90TreeVisitor< auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params_aux.ptr_aux)); + gmem_ptr ptr_aux = make_gmem_ptr(params_aux.ptr_aux); Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) @@ -765,7 +765,7 @@ struct Sm90AuxLoad< auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params.ptr_aux)); + gmem_ptr ptr_aux = make_gmem_ptr(params.ptr_aux); Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 29b9d1d1..06ad8082 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -1173,8 +1173,9 @@ public: CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Layout ref_layout_MN = [&] () { - if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } - else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } }(); // tile_mn -> tv_idx // Get the MN layout + coord of lanes to determine shuffle reduction iterations @@ -1650,8 +1651,9 @@ public: CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Layout ref_layout_MN = [&] () { - if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } - else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } }(); // tile_mn -> tv_idx // Get the MN layout + coord of lanes to determine shuffle reduction iterations diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp index 330e1fde..bd378419 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp @@ -93,7 +93,7 @@ Array top_2_reduce(Array a, Array b) { " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 - "}\n" : "=f"(out[0]), "=f"(out[1]) : + "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); return out; } @@ -117,8 +117,8 @@ Array top_4_reduce_scalar(Array a, float scalar) { " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); return out; } @@ -187,8 +187,8 @@ Array top_4_reduce(Array a, Array b) { " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); return out; @@ -351,7 +351,7 @@ private: // we can track logsumexp instead of tracking two variables (sum of exps and the max). // In addition, subtracting logsumexp from any element and taking its exp is equivalent to // computing its softmax. - // + // // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the // way at all, because any element not in the top-K is going to be masked out and set to 0. // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and @@ -370,7 +370,7 @@ private: ReductionResult() { } CUTLASS_DEVICE - ReductionResult(ElementCompute min, ElementCompute logsumexp): + ReductionResult(ElementCompute min, ElementCompute logsumexp): logsumexp_(logsumexp), min_(min) { } // Warp shuffle broadcast @@ -541,7 +541,7 @@ public: visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, Array const& frg_input) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, lane_layout_MN, lane_mn, residue_cCol, residue_tCcCol] = args_tuple; Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); @@ -566,7 +566,7 @@ public: CUTLASS_DEVICE void reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, lane_layout_MN, lane_mn, residue_cCol, residue_tCcCol] = args_tuple; @@ -668,7 +668,7 @@ public: CUTLASS_DEVICE void end_loop(int epi_m, int epi_n) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, lane_layout_MN, lane_mn, residue_cCol, residue_tCcCol] = args_tuple; @@ -690,8 +690,9 @@ public: CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Layout ref_layout_MN = [&] () { - if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } - else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } }(); // tile_mn -> tv_idx // Get the MN layout + coord of lanes to determine shuffle reduction iterations @@ -739,7 +740,7 @@ public: Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) - // Compose the new accumulator R2S layout with the expected tCrC layout to get final + // Compose the new accumulator R2S layout with the expected tCrC layout to get final // reduction tensor layout. auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) diff --git a/include/cutlass/gemm/collective/builders/sm100_common.inl b/include/cutlass/gemm/collective/builders/sm100_common.inl index 6230c616..385fea22 100644 --- a/include/cutlass/gemm/collective/builders/sm100_common.inl +++ b/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -569,6 +569,47 @@ sm100_make_trivial_fastFP32_tiled_mma() { } } +template< + class CtaShape_MNK +> +constexpr auto +sm100_simt_f32_warp_shape_mnk_selector() { + using namespace cute; + + constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{}); + constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{}); + constexpr int CtaShape_K = cute::size<2>(CtaShape_MNK{}); + + // CTA tile shape M and N are supposed to be divisible by 32. + static_assert(CtaShape_M % 32 == 0, "CtaShape_M needs to be divisible by 32."); + static_assert(CtaShape_N % 32 == 0, "CtaShape_N needs to be divisible by 32."); + + // WarpShape_MNK configuration + // We assume WarpShape_K is always 1 in our SM100 SIMT SGEMM implementation. + if constexpr (CtaShape_M >= CtaShape_N) { + if constexpr (CtaShape_M == 256 && CtaShape_N == 128) { + return cute::Shape<_4, _2, _1>{}; + } + else if constexpr ((CtaShape_M == 64 || CtaShape_M == 32) && CtaShape_N == 32) { + return cute::Shape<_1, _2, _1>{}; + } + else { + return cute::Shape<_2, _2, _1>{}; + } + } + else { + if constexpr (CtaShape_M == 128 && CtaShape_N == 256) { + return cute::Shape<_2, _4, _1>{}; + } + else if constexpr (CtaShape_M == 32 && CtaShape_N == 64) { + return cute::Shape<_1, _2, _1>{}; + } + else { + return cute::Shape<_1, _4, _1>{}; + } + } +} + template < class ElementPairA, diff --git a/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl b/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl new file mode 100644 index 00000000..15ad6bc2 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl @@ -0,0 +1,216 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template< + class LayoutA, + int AlignmentA, + class LayoutB, + int AlignmentB, + class CtaShape_MNK, + class WarpShape_MNK +> +constexpr auto +sm100_make_simt_f32_tiled_mma() { + using namespace cute; + + constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{}); + constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{}); + constexpr int CtaShape_K = cute::size<2>(CtaShape_MNK{}); + + constexpr int WarpShape_M = cute::size<0>(WarpShape_MNK{}); + constexpr int WarpShape_N = cute::size<1>(WarpShape_MNK{}); + constexpr int WarpShape_K = cute::size<2>(WarpShape_MNK{}); + + // Use Permutation to achieve a [4 x 4] value layout for each thread. + // Ideally, we want the tiled mma to be such that loads from shared memory are 128 bit wide. + // While as we are using CtaShape_K = 16, when A and B are K-major, we use tranpose + 8 byte padding to avoid smem bank conflict, + // so we could only use 64 bit smem load. + // When A and B are MN-major, we use 128 bit smem load. + using PermutationA = Layout, _2>, Stride< _1, _4, _2>>; + using PermutationB = Layout, _4>, Stride< _4, _1>>; + + // For 32 threads in 1 warp, we use [8 x 4] thread layouts and each thread will hold [4 x 4] value layouts. + // Then totally each warp will hold [32 x 16] value layouts. + // So WarpShape_M needs to be equal or smaller than CtaShape_M / 32 and WarpShape_N needs to be equal or smaller than CtaShape_N / 16. + static_assert(WarpShape_M <= CtaShape_M / 32, "WarpShape_M is too large, it needs to be equal or smaller than CtaShape_M / 32."); + static_assert(WarpShape_N <= CtaShape_N / 16, "WarpShape_N is too large, it needs to be equal or smaller than CtaShape_N / 16."); + + constexpr int WarpStride_M = (WarpShape_M != 1) * NumThreadsPerWarp; + constexpr int WarpStride_N = WarpShape_M * NumThreadsPerWarp; + + // We first introduce a [8 x 4] thread layouts in 1 warp. + // And inside this [8 x 4] thread layouts, each 4 threads will be arranged as [2 x 2]. + // Then we could set different WarpShape to finalize how many warps we use in our tiled mma. + // For example : + // With 128 threads in the tiled mma, we could set the WarpShapeMNK as [2 x 2 x 1], [1 x 4 x 1] and [4 x 1 x 1]. + // With 64 threads in the tiled mma, we could set the WarpShapeMNK as [1 x 2 x 1] and [2 x 1 x 1]. + return make_tiled_mma( + MMA_Atom{}, + Layout>, Shape <_2, _2, Int>, _1>, + Stride< Stride<_1, _8, Int>, Stride<_2, _4, Int>, _1>>{}, + Tile< + PermutationA, + PermutationB, + Underscore>{}); +} + +} // namespace detail + +template < + class GmemLayoutATag, + int AlignmentA, + class GmemLayoutBTag, + int AlignmentB, + class CtaShape_MNK, + class ClusterShape_MNK, + int stages, + class BuilderScheduleTag> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassSimt, + float, + GmemLayoutATag, + AlignmentA, + float, + GmemLayoutBTag, + AlignmentB, + float, + CtaShape_MNK, + ClusterShape_MNK, + StageCount, + BuilderScheduleTag, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + ((sizeof(float) * AlignmentA) % detail::cp_async_min_alignment_bytes == 0) && + ((sizeof(float) * AlignmentB) % detail::cp_async_min_alignment_bytes == 0) >> { + static_assert(cute::size<2>(CtaShape_MNK{}) == 16, "SM100 SIMT SGEMM Kernels only support TileShape_K = 16."); + + // This kernel is specialized for F32 data type. + using ElementA = float; + using ElementB = float; + + using M = decltype(cute::size<0>(CtaShape_MNK{})); + using N = decltype(cute::size<1>(CtaShape_MNK{})); + using K = decltype(cute::size<2>(CtaShape_MNK{})); + + using WarpShape_MNK = decltype(detail::sm100_simt_f32_warp_shape_mnk_selector()); + + static constexpr int ThreadCount = cute::size(WarpShape_MNK{}) * NumThreadsPerWarp; + + using TiledMma = decltype( + detail::sm100_make_simt_f32_tiled_mma< + GmemLayoutATag, + AlignmentA, + GmemLayoutBTag, + AlignmentB, + CtaShape_MNK, + WarpShape_MNK>()); + + // for K major layouts, add a smem alignment offset to avoid bank conflicts + static constexpr int SmemAlignmentOffsetA = cutlass::gemm::detail::is_mn_major_A() ? 0 : 2; + static constexpr int SmemAlignmentOffsetB = cutlass::gemm::detail::is_mn_major_B() ? 0 : 2; + static constexpr int CtaShape_M = cute::size<0>(CtaShape_MNK{}); + static constexpr int CtaShape_N = cute::size<1>(CtaShape_MNK{}); + + // Shared memory layout is [M x K] in M-major + using SmemLayoutAtomA = cute::Layout, + cute::Stride<_1, Int>>; + // A M-major use 128bit smem load. + // A K-major needs to do tranpose and 8 byte padding to make smem bank conflict free, then we can only use 64bit smem load. + using SmemCopyAtomA = std::conditional_t(), + cute::Copy_Atom, ElementA>, + cute::Copy_Atom, ElementA>>; + + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype( + detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, ThreadCount, AlignmentA, TagToStrideA_t, M, K>()); + + // Shared memory layout is [N x K] in N-major + using SmemLayoutAtomB = cute::Layout, + cute::Stride<_1, Int>>; + // B N-major use 128bit smem load. + // B K-major needs to do tranpose and 8 byte padding to make smem bank conflict free, then we can only use 64bit smem load. + using SmemCopyAtomB = std::conditional_t(), + cute::Copy_Atom, ElementB>, + cute::Copy_Atom, ElementB>>; + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype( + detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, ThreadCount, AlignmentB, TagToStrideB_t, N, K>()); + + static constexpr bool IsArrayOfPointersGemm = cute::is_same_v; + using DispatchPolicy = cute::conditional_t, + cutlass::gemm::MainloopSm80CpAsync + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + CtaShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index b03c79c8..c734671b 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -46,6 +46,7 @@ #include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_simt_builder.inl" #include "cutlass/gemm/collective/builders/sm120_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl" diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index f65dd70b..0bb6f722 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -37,6 +37,7 @@ #include "cutlass/gemm/collective/sm70_mma_twostage.hpp" #include "cutlass/gemm/collective/sm80_mma_multistage.hpp" +#include "cutlass/gemm/collective/sm80_mma_array_multistage.hpp" #include "cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp" diff --git a/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp new file mode 100644 index 00000000..b83e0489 --- /dev/null +++ b/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_ +> +struct CollectiveMma< + MainloopSm80ArrayCpAsync< + Stages, + ClusterShape_>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ + > +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm80ArrayCpAsync< + Stages, + ClusterShape_>; + using TileShape = TileShape_; + // Follow the change in TestSmall: TileShape switch to CtaShape + // In legacy arch, it should be same + using CtaShape_MNK = TileShape; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, // (BLK_M, BLK_K, K_TILES) + TensorB gB, // (BLK_N, BLK_K, K_TILES) + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M + CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K + CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA); + gB = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_A; + GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // PREFETCH + // + + // Clear the smem tiles to account for predicated off loads + clear(tAsA); + clear(tBsB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + constexpr int k_pipe = 0; + + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); + } + } + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // Start async loads for 1st k-tile onwards, no k-residue handling needed + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // + // MMA Atom partitioning + // + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + + // + // PIPELINED MAIN LOOP + // + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + } + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // Pipeline the outer products with a static for loop. + // + // Note, the for_each() function is required here to ensure `k_block` is of type Int. + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // Set all predicates to false if we are going to overshoot bounds + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + ++k_tile_iter; + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 0 : smem_pipe_read; + } + + // Transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + + } + + cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp index 653db90a..fa5e212d 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -89,15 +89,10 @@ struct CollectiveMma< TransformB_> { public: - enum class ConversionMode { - DirectConvert, - ConvertAndScale, - ConvertAndScaleWithZero - }; - // // Type Aliases // + using ConversionMode = cutlass::detail::ConversionMode; using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput; using TileShape = TileShape_; using KernelSchedule = KernelSchedule_; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 4e435299..2558350c 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -96,15 +96,11 @@ struct CollectiveMma< TransformB_> { public: - enum class ConversionMode { - DirectConvert, - ConvertAndScale, - ConvertAndScaleWithZero - }; // // Type Aliases // + using ConversionMode = cutlass::detail::ConversionMode; using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; using TileShape = TileShape_; using KernelSchedule = KernelSchedule_; diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index aa19fbc2..d051aa7a 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -109,6 +109,7 @@ static constexpr bool HasAuxiliaryLoad_v = HasAuxiliaryLoad::value; // Kernel schedule policies (the base class tags, one for each kernel layer file) // struct KernelMultistage { }; +struct KernelPtrArrayMultistage { }; struct KernelCpAsyncWarpSpecialized { }; struct KernelCpAsyncWarpSpecializedPingpong { }; struct KernelCpAsyncWarpSpecializedCooperative { }; @@ -198,6 +199,17 @@ struct MainloopSm80CpAsync { using ClusterShape = ClusterShape_; }; +// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads for SM100 Simt Ptr-Array +template +> +struct MainloopSm80ArrayCpAsync { + constexpr static int Stages = Stages_; + using ArchTag = cute::conditional_t<(size(ClusterShape_{}) > 1), arch::Sm90, arch::Sm80>; + using Schedule = KernelPtrArrayMultistage; + using ClusterShape = ClusterShape_; +}; + // n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule template< int Stages_, @@ -479,6 +491,16 @@ struct KernelTmaWarpSpecializedInputTransformSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +// InputTransform GEMM +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedMixedInputTransformSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // Ptr-Array Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms template< int SchedulerPipelineStageCount_, diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 08605e00..e212a761 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -54,6 +54,7 @@ struct IsCutlass3ArrayKernel( - TiledMma{}, - work_tile_info, - accumulators, - mma2accum_pipeline, - mma2accum_pipeline_consumer_state, - typename CollectiveEpilogue::CopyOpT2R{} - ); - - // - // Epilogue and write to gD - // - if (scheduler.compute_epilogue(work_tile_info)) { - auto [mma2accum_pipeline_state_next] = collective_epilogue( - mma2accum_pipeline, - mma2accum_pipeline_consumer_state, - problem_shape_MNKL, - CtaShape_MNK{}, - cta_coord_mnkl, - accumulators, - shared_storage.tensors.epilogue - ); - // Advance the mma2accum pipe - mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; - } - } // Complex kernels use a collective epilogue else { mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state); diff --git a/include/cutlass/gemm/kernel/sm70_gemm_array.hpp b/include/cutlass/gemm/kernel/sm70_gemm_array.hpp new file mode 100644 index 00000000..c0ef53a7 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm70_gemm_array.hpp @@ -0,0 +1,279 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/tensor.hpp" + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, + cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + static constexpr bool IsGdcEnabled = false; + + static constexpr bool is_valid_tile_scheduler = + cute::is_void_v or cute::is_same_v; +static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler."); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + // MSVC requires the cast to fix a warning-as-error. + static constexpr int SharedStorageSize = static_cast(cute::max( + sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage))); + + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{})); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + typename ProblemShape::UnderlyingProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape(); + + KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count}; + auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + + bool implementable = (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape(); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + return workspace_size; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + cutlass::Status status = Status::kSuccess; + + return status; + } + + static dim3 + get_grid_shape(Params const& params) { + int batch_count = cute::size<3>(params.problem_shape); + return dim3( + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), + batch_count + ); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto [M,N,K,L] = problem_shape_MNKL; + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + int thread_idx = int(threadIdx.x); + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); + auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l) + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A[l_coord]), make_shape(M,K,1), params.mainloop.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B[l_coord]), make_shape(N,K,1), params.mainloop.dB); //(n,k,l) + + // Get batch slice + Tensor mA_mk = mA_mkl(_,_,0); // (m,k) + Tensor mB_nk = mB_nkl(_,_,0); // (n,k) + + // Slice to get the tiles this thread block is responsible for + Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + clear(accumulators); + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + int k_tile_count = size<2>(gA); + + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + collective_mma( + accumulators, + gA, + gB, + accumulators, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + smem_buf + ); + + // Epilogue and write to gD + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + blk_shape, + blk_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 097605c0..030e5845 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -194,6 +194,9 @@ struct integer_subbyte { /////////////////////////////////////////////////////////////////////////////////////////////////// +/// 1-bit binary type +using bin1_t = bool; + /// 1-bit Unsigned integer type using uint1b_t = integer_subbyte<1, false>; @@ -209,14 +212,12 @@ using int4b_t = integer_subbyte<4, true>; /// 4-bit Unsigned integer type using uint4b_t = integer_subbyte<4, false>; +/// 6-bit integer type +using int6b_t = integer_subbyte<6, true>; /// 6-bit unsigned integer type using uint6b_t = integer_subbyte<6, false>; - -/// 1-bit binary type -using bin1_t = bool; - /////////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/include/cutlass/numeric_size.h b/include/cutlass/numeric_size.h index 4f267e51..0d8f2ada 100644 --- a/include/cutlass/numeric_size.h +++ b/include/cutlass/numeric_size.h @@ -50,7 +50,13 @@ struct sizeof_bits { }; template -struct sizeof_bits: sizeof_bits {}; +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; + +template +struct sizeof_bits : sizeof_bits {}; template <> struct sizeof_bits { diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst index 18fd2528..e3c9316e 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst @@ -10,109 +10,130 @@ Control Flow Overview -------- -|DSL| walks Python’s AST and converts each control-flow construct it finds into +|DSL| walks Python's AST and converts each control-flow construct it finds into structured |IR|. You can therefore write ordinary Python loops and branches while the compiler decides—statement by statement—whether to -* **evaluate at compile time** if the controlling value is a |Constexpr|, or -* **emit intermediate representation (IR)** when the value is dynamic. +* **evaluate at compile time** if it's a native Python control flow, or +* **emit intermediate representation (IR)** when the control flow is marked as dynamic. +Passing |IR| values to a native Python control flow will result in an error. For a high-level discussion of the overall pipeline, see :doc:`the code-generation overview `. + For Loops --------- |DSL| recognises three kinds of ranges for ``for`` loops: -* ``range`` – the Python built-in -* ``cutlass.range_dynamic`` – always lowers to |IR| -* ``cutlass.range_constexpr`` – always unrolls at compile time +* ``range`` – the Python built-in, always lowered to |IR| +* ``cutlass.range`` - Same as Python built-in ``range``, but supports advanced unrolling and pipelining control +* ``cutlass.range_constexpr`` – unrolled at compile time -range(...) -~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The AST rewriter inserts a small helper stub. At runtime the loop bounds are -inspected: - -* **Constant bounds** → the loop is unrolled at compile time. -* **Dynamic bounds** → the loop is emitted as structured |IR|. - - -cutlass.range_dynamic(...) -~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Use when you *always* want a loop in the generated |IR|, even if the bounds -look constant. - +range(...)/cutlass.range(...) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Use when you *always* want a loop in the generated |IR|, even if the inputs +are Python values. cutlass.range_constexpr(...) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Runs in the Python interpreter and is fully unrolled before code generation. All loop indices must be |Constexpr|. -Limitations of Dynamic For Loops -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -* Early-exit ``break``, ``continue``, or raising exception are not yet supported. -* Operations in the loop body are traced only when tracing is active in that - region. - - **Example:** .. code-block:: python - @cute.jit - def loop_example(): - n = 10 + @cute.jit + def control_flow_examples(bound: cutlass.Int32): + n = 10 - # ❌ This loop is dynamic, early-exit isn't allowed. - for i in cutlass.range_dynamic(n): - if i == 5: - break # Early-exit - cute.printf("%d\\n", i) + # ✅ This loop is Python loop, evaluated at compile time. + for i in cutlass.range_constexpr(n): + cute.printf("%d\\n", i) + + # ✅ This loop is dynamic, even when bound is Python value. + for i in range(n): + cute.printf("%d\\n", i) + + # ❌ This loop bound is a dynamic value, not allowed in Python loop. + # Should use `range` instead. + for i in cutlass.range_constexpr(bound): + cute.printf("%d\\n", i) + + # ✅ This loop is dynamic, emitted IR loop. + for i in range(bound): + cute.printf("%d\\n", i) + + # ✅ This loop is dynamic, emitted IR loop with unrolling + for i in cutlass.range(bound, unroll=2): + cute.printf("%d\\n", i) - # ✅ This loop is constexpr, early-exit is allowed. - for i in cutlass.range_constexpr(n): - if i == 5: - break # Early-exit - cute.printf("%d\\n", i) If-Else Statements ------------------ -Standard Python ``if``/``else`` is supported. +Standard Python ``if``/``elif``/``else`` is supported. -* **Predicate is Constexpr (compile-time Python value)** → evaluated at compile time. -* **Predicate is dynamic** → lowered to |IR|. +* **Predicate without annotation** → lowered to |IR|. +* **Predicate annotated with `cutlass.const_expr`** → evaluated at compile time. **Example:** .. code-block:: python - @cute.jit - def main(const_var: cutlass.Constexpr, dynamic_var: cutlass.Int32): - if const_var: # compile-time branch - cute.printf("Const branch\\n") - else: - cute.printf("Const else\\n") + @cute.jit + def main(const_var: cutlass.Constexpr, dynamic_var: cutlass.Int32): + # ✅ This branch is Python branch, evaluated at compile time. + if cutlass.const_expr(const_var): + cute.printf("Const branch\\n") + else: + cute.printf("Const else\\n") - if dynamic_var == 10: # dynamic branch - cute.printf("Dynamic True\\n") - else: - cute.printf("Dynamic False\\n") + # ✅ This branch is dynamic branch, emitted IR branch. + if dynamic_var == 10: + cute.printf("Dynamic True\\n") + else: + cute.printf("Dynamic False\\n") + + # ❌ Using a dynamic value with `cutlass.const_expr` is not allowed. + if cutlass.const_expr(dynamic_var == 10): + cute.printf("Bound is 10\\n") -Similarly to for-loops, the ``if cutlass.const_expr`` and ``if cutlass.dynamic_expr`` constructs can -be used to force the evaluation at compile-time or the generation of IR, respectively. Unstructured -control flow is only supported when using ``if cutlass.const_expr``. While Loops ----------- -Python ``while`` loops are always treated as **dynamic** because the loop condition may become -dynamic after the first iteration. Similarly to for-loops and ``if``/``else``, the -``while cutlass.const_expr`` and ``while cutlass.dynamic_expr`` constructs are available. +Standard Python ``while`` is supported. + +* **Condition without annotation** → lowered to |IR|. +* **Condition annotated with `cutlass.const_expr`** → evaluated at compile time. + +**Example:** + +.. code-block:: python + + @cute.jit + def main(dynamic_var: cutlass.Int32): + n = 0 + + # ✅ This is Python while loop, evaluated at compile time. + while cutlass.const_expr(n < 10): + cute.printf("Const branch\\n") + n += 1 + + # ✅ This is dynamic while loop, emitted IR while loop. + while dynamic_var == 10: + cute.printf("Dynamic True\\n") + n += 1 + + # ❌ Using a dynamic value with `cutlass.const_expr` is not allowed. + while cutlass.const_expr(n < dynamic_var): + n += 1 + Compile-Time Metaprogramming ---------------------------- @@ -127,7 +148,7 @@ an optional **ReLU** epilogue: def gemm(..., do_relu: cutlass.Constexpr): # main GEMM work ... - if const_expr(do_relu): # compile-time guard + if cutlass.const_expr(do_relu): # compile-time guard # ReLU code is emitted only when do_relu is True ... @@ -135,3 +156,45 @@ an optional **ReLU** epilogue: gemm(..., False) # ReLU is omitted from the generated |IR| gemm(..., True) # ReLU is included + + +Limitations of Dynamic Control Flow +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* Early-exit ``break``, ``continue``, ``pass`` or raising exception from + control flow body are not yet supported. +* Operations in the control flow body are traced only when tracing is active in + that region. +* Values originating in control flow body are not available outside the control + flow. +* Changing type of a variable in control flow body is not allowed. + +**Example:** + +.. code-block:: python + + @cute.jit + def control_flow_negative_examples(predicate: cutlass.Boolean): + n = 10 + + # ❌ This loop is dynamic, early-exit isn't allowed. + for i in cutlass.range_dynamic(n): + if i == 5: + break # Early-exit + + if predicate: + val = 10 + # ❌ return from control flow body is not allowed. + return + # ❌ Raising exception from control flow body is not allowed. + raise ValueError("This is not allowed") + # ❌ Using pass in control flow body is not allowed. + pass + + # ❌ val is not available outside the dynamic if + cute.printf("%d\\n", val) + + if predicate: + # ❌ Changing type of a variable in control flow body is not allowed. + n = 10.0 + diff --git a/media/docs/pythonDSL/faqs.rst b/media/docs/pythonDSL/faqs.rst index 90348d8d..0ec1a2ea 100644 --- a/media/docs/pythonDSL/faqs.rst +++ b/media/docs/pythonDSL/faqs.rst @@ -39,7 +39,7 @@ General the GitHub code only exists as a way for users to file issues and pull requests against. While it can be used with the pip wheel, we do not recommend most users do so unless they are hacking on the DSL itself. For all other users, we recommend they - simply ``pip install nvidia-cutlas-dsl`` and use the pip wheel as the single source + simply ``pip install nvidia-cutlass-dsl`` and use the pip wheel as the single source of truth for the dialect compiler and DSL implementation. CUTLASS GitHub repository will contain a ``requirements.txt`` file pinning the version of the wheel consistent with the state of the OSS repository (please see :doc:`quick_start`). This means getting started with diff --git a/media/docs/pythonDSL/limitations.rst b/media/docs/pythonDSL/limitations.rst index 7be5b051..59c9ad51 100644 --- a/media/docs/pythonDSL/limitations.rst +++ b/media/docs/pythonDSL/limitations.rst @@ -18,7 +18,6 @@ Notable unsupported features ---------------------------- - GeForce RTX 50 Series support -- RS WGMMA (The input matrix A comes from register and the input matrix B comes from shared memory) - Programmatic Dependent Launch (PDL) - narrow-precision data type support, including related tensor core instructions - convolutions @@ -31,6 +30,10 @@ Notable unsupported features Programming Model --------------------- +**CuTe Layout Algebra Only support 32bit** + Today, we only support 32bit shapes/strides in CuTe layouts. 64bit or arbitrary + width support is planned for future releases. + **Python Native Data Types** CuTe DSL supports Python data structures when used for "meta-programming," but these structures cannot be treated as dynamic values modifiable at runtime. diff --git a/python/CuTeDSL/base_dsl/ast_helpers.py b/python/CuTeDSL/base_dsl/ast_helpers.py index cc5cadb6..756d151f 100644 --- a/python/CuTeDSL/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/base_dsl/ast_helpers.py @@ -15,6 +15,8 @@ The preprocessor read through python's ast and changes the input code. """ from typing import Callable, Iterator, Optional, overload +from typing_extensions import deprecated +import warnings from .utils.logger import log from .common import * @@ -30,13 +32,9 @@ class Executor: set_functions: Assigns the functions for checking loop bounds and conditional evaluation. - for_dynamic: Generates MLIR for OP - for_constexpr: Executes a for loop at JIT compile-time - for_execute: Decides whether to execute the loop at compile-time or generate MLIR for OP based on the provided bounds. - - if_dynamic: Generates MLIR if OP - if_constexpr: Executes a if at JIT compile-time by python interpreter - if_execute: Decides whether to execute the if statement at compile-time or generate MLIR if OP based on the predicate. + for_execute: Generates MLIR for OP + while_execute: Generates MLIR while OP + if_execute: generate MLIR if OP """ def __init__(self): @@ -44,6 +42,9 @@ class Executor: self._loop_execute_range_dynamic = None self._if_dynamic = None self._while_dynamic = None + self._compare_executor = None + self._any_executor = None + self._all_executor = None def set_functions( self, @@ -51,11 +52,17 @@ class Executor: loop_execute_range_dynamic: Callable, if_dynamic: Callable, while_dynamic: Callable, + compare_executor: Callable, + any_executor: Callable = None, + all_executor: Callable = None, ): self._is_dynamic_expression = is_dynamic_expression self._loop_execute_range_dynamic = loop_execute_range_dynamic self._if_dynamic = if_dynamic self._while_dynamic = while_dynamic + self._compare_executor = compare_executor + self._any_executor = any_executor + self._all_executor = all_executor @staticmethod def convert_to_list(x): @@ -83,31 +90,6 @@ class Executor: return res[0] return res - def for_dynamic( - self, - func: Callable, - start, - stop, - step, - used_args: list, - iter_args: list, - iter_arg_names: list, - unroll=bool, - unroll_full=int, - ): - log().debug("start [%s] stop [%s] step [%s]", start, stop, step) - return self._loop_execute_range_dynamic( - func, - start, - stop, - step, - used_args, - iter_args, - iter_arg_names, - unroll, - unroll_full, - ) - @staticmethod def for_constexpr( func: Callable, @@ -143,44 +125,14 @@ class Executor: iter_arg_names=[], unroll=-1, unroll_full=False, - is_range_constexpr=None, + pipelining=None, ): assert ( - self._loop_execute_range_dynamic and self._is_dynamic_expression + self._loop_execute_range_dynamic ), "Functions must be set before execution." log().debug("start [%s] stop [%s] step [%s]", start, stop, step) - any_dynamic_expression = ( - self._is_dynamic_expression(start) - or self._is_dynamic_expression(stop) - or self._is_dynamic_expression(step) - ) - if is_range_constexpr is None: - if not any_dynamic_expression: - return self.for_constexpr(func, start, stop, step, used_args, iter_args) - else: - return self.for_dynamic( - func, - start, - stop, - step, - used_args, - iter_args, - iter_arg_names, - unroll, - unroll_full, - ) - - # Ensure bounds are compile-time constants for constexpr execution - if is_range_constexpr: - if any_dynamic_expression: - raise DSLRuntimeError( - "Loop bounds must be constexpr (compile-time constants)" - ) - return self.for_constexpr(func, start, stop, step, used_args, iter_args) - - # MLIR generation - return self.for_dynamic( + return self._loop_execute_range_dynamic( func, start, stop, @@ -190,40 +142,9 @@ class Executor: iter_arg_names, unroll, unroll_full, + pipelining, ) - def if_dynamic( - self, - pred, - then_block: Callable, - else_block: Optional[Callable] = None, - used_args=[], - yield_args=[], - yield_arg_names=[], - ): - return self._if_dynamic( - pred, then_block, else_block, used_args, yield_args, yield_arg_names - ) - - @staticmethod - def if_constexpr( - pred, - then_block: Callable, - else_block: Optional[Callable] = None, - used_args=[], - yield_args=[], - ): - if pred: - log().debug(" running then block [%s]", yield_args) - res = then_block(*used_args, *yield_args) - log().debug("result [%s]", res) - return Executor.converge_ret_val(res) - elif else_block is not None: - log().debug("running else [%s]", yield_args) - res = else_block(*used_args, *yield_args) - log().debug("result [%s]", res) - return Executor.converge_ret_val(res) - def if_execute( self, pred, @@ -232,94 +153,14 @@ class Executor: used_args=[], yield_args=[], yield_arg_names=[], - if_constexpr=None, ): - assert ( - self._if_dynamic and self._is_dynamic_expression - ), "Functions must be set before execution." - - is_if_constexpr = not self._is_dynamic_expression(pred) - if if_constexpr is None: - if is_if_constexpr: - return self.if_constexpr( - pred, then_block, else_block, used_args, yield_args - ) - else: - return self.if_dynamic( - pred, then_block, else_block, used_args, yield_args, yield_arg_names - ) - - # Ensure bounds are compile-time constants for constexpr execution - if if_constexpr: - if not is_if_constexpr: - raise DSLRuntimeError( - "If predicate must be constexpr (compile-time constants)" - ) - return self.if_constexpr( - pred, then_block, else_block, used_args, yield_args - ) + assert self._if_dynamic, "Functions must be set before execution." # MLIR generation - return self.if_dynamic( + return self._if_dynamic( pred, then_block, else_block, used_args, yield_args, yield_arg_names ) - def while_dynamic( - self, - while_before_block: Callable, - while_after_block: Callable, - used_args=[], - yield_args=[], - yield_arg_names=[], - ): - return self._while_dynamic( - while_before_block, - while_after_block, - used_args, - yield_args, - yield_arg_names, - ) - - @staticmethod - def while_constexpr( - while_before_block, - while_after_block, - used_args=[], - yield_args=[], - ): - log().debug( - "while_constexpr begin %s", while_before_block.__qualname__ - ) - cond, loop_results = while_before_block(*used_args, *yield_args) - while cond: - loop_results = Executor.convert_to_list(loop_results) - log().debug( - "calling while_after [%s], [%s]", - used_args, - loop_results, - ) - loop_results = while_after_block(*used_args, *loop_results) - log().debug( - "while after [%s]", loop_results - ) - loop_results = Executor.convert_to_list(loop_results) - log().debug( - "calling while_before [%s], [%s]", - used_args, - loop_results, - ) - cond, loop_results = while_before_block(*used_args, *loop_results) - log().debug( - "while_before cond, results [%s], [%s]", - cond, - loop_results, - ) - - log().debug( - "while_constexpr results %s", loop_results - ) - return Executor.converge_ret_val(loop_results) - def while_execute( self, pred, @@ -328,26 +169,11 @@ class Executor: used_args=[], yield_args=[], yield_arg_names=[], - while_constexpr=None, ): - assert ( - self._while_dynamic and self._is_dynamic_expression - ), "Functions must be set before execution." - - is_while_constexpr = not self._is_dynamic_expression(pred) - - # Ensure bounds are compile-time constants for constexpr execution - if while_constexpr: - if not is_while_constexpr: - raise DSLRuntimeError( - "While predicate must be constexpr (compile-time constants)" - ) - return self.while_constexpr( - while_before_block, while_after_block, used_args, yield_args - ) + assert self._while_dynamic, "Functions must be set before execution." # MLIR generation - return self.while_dynamic( + return self._while_dynamic( while_before_block, while_after_block, used_args, @@ -367,15 +193,16 @@ def loop_selector( start, stop, step, + *, used_args=[], iter_args=[], iter_arg_names=[], unroll=-1, unroll_full=False, - constexpr=None, + pipelining=None, ): log().debug( - "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]", + "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]", start, stop, step, @@ -383,7 +210,7 @@ def loop_selector( iter_args, unroll, unroll_full, - constexpr, + pipelining, ) from .typing import Integer, Numeric @@ -408,7 +235,7 @@ def loop_selector( iter_arg_names, unroll, unroll_full, - constexpr, + pipelining, ) return ir_loop @@ -443,7 +270,6 @@ def while_executor( used_args=[], yield_args=[], yield_arg_names=[], - constexpr=None, ): return executor.while_execute( pred, @@ -452,7 +278,6 @@ def while_executor( used_args, yield_args, yield_arg_names, - constexpr, ) @@ -463,10 +288,9 @@ def if_executor( used_args=[], yield_args=[], yield_arg_names=[], - constexpr=None, ): return executor.if_execute( - pred, then_block, else_block, used_args, yield_args, yield_arg_names, constexpr + pred, then_block, else_block, used_args, yield_args, yield_arg_names ) @@ -475,75 +299,70 @@ def if_executor( # ============================================================================= -class range_dynamic: +class range: @overload - def __new__(cls, stop, unroll=0, unroll_full=False): + def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None): pass @overload - def __new__(cls, start, stop, step, unroll=0, unroll_full=False): + def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None): pass def __new__(cls, *args, **kwargs): - raise DSLRuntimeError("range_dynamic should be always preprocessed to IR") - - -class range_constexpr: - def __init__(self, *args): - if len(args) == 1: - self.start = 0 - self.stop = args[0] - self.step = 1 - elif len(args) == 2: - self.start, self.stop = args - self.step = 1 - elif len(args) == 3: - self.start, self.stop, self.step = args - else: - raise DSLRuntimeError( - "range_constexpr supports up to 3 arguments (start, stop, step)" - ) - # Ensure the arguments are compile-time constants (if required) - for arg_name, arg_value in [ - ("step", self.step), - ("start", self.start), - ("stop", self.stop), - ]: - if executor._is_dynamic_expression(arg_value): - raise DSLRuntimeError( - f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, " - f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL " - f"will handle them during runtime. ", - suggestion="Use `range` instead of `range_constexpr`.", - ) + raise DSLRuntimeError("dynamic range should be always preprocessed to IR") def __iter__(self) -> Iterator[int]: - current = self.start - while current < self.stop: - yield current - current += self.step + raise DSLRuntimeError("dynamic range should be always preprocessed to IR") +@deprecated( + "range_dynamic is deprecated and will be removed in the future, please remove it." +) +def range_dynamic(*args, **kwargs): + raise DSLRuntimeError("range_dynamic should be always preprocessed to IR") + + +def range_constexpr(*args): + raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.") + # ============================================================================= # If expressions # ============================================================================= def const_expr(expression): - if executor._is_dynamic_expression(expression): + """ + This function is used to check if the expression is a python value. + If the expression is a python value, return the boolean value of the expression. + If the expression is a dynamic expression, raise an error. + """ + from .typing import Numeric + + failed = False + + if isinstance(expression, Numeric): + if isinstance(expression.value, (int, float, bool)): + return expression.value + else: + failed = True + elif executor._is_dynamic_expression(expression): + failed = True + + if failed: raise DSLRuntimeError( f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).", context={ - "const_expr": "Accepts only constexpr (compile-time constant)", - "If your expression depends on dynamic values": "Avoid marking it as `const_expr()`", - "If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically", + "If your expression depends on dynamic values": "Remove `const_expr()`", }, ) return expression +@deprecated( + "dynamic_expr is deprecated and will be removed in the future, please remove it." +) def dynamic_expr(expression): - raise DSLRuntimeError("dynamic_expr should be always preprocessed to IR") + return expression # ============================================================================= @@ -582,3 +401,86 @@ def bool_cast(value): suggestion = "Please explicitly convert to boolean with expressions like comparision." ) return bool(value) + +def compare_executor(left, comparators, ops): + """ + Executes comparison operations with a left operand and a list of comparators. + + Args: + left: The leftmost value in the comparison chain + comparators: A list of values to compare against + ops: A list of comparison operators to apply + + Returns: + The result of the comparison chain + + Raises: + AssertionError: If the executor function is not set before execution + """ + assert ( + executor._compare_executor is not None + ), "Function must be set before execution." + return executor._compare_executor(left, comparators, ops) + + +def any_executor(iterable): + """Executes the 'any' operation on an iterable, handling both dynamic and static expressions. + + :param iterable: An iterable to check if any elements evaluate to True + :type iterable: Iterable + :return: boolean of Python value or IR value + :rtype: bool or cutlass.Boolean + + """ + if executor._any_executor and executor._is_dynamic_expression(iterable): + return executor._any_executor(iterable) + else: + return any(iterable) + + +def all_executor(iterable): + """Executes the 'all' operation on an iterable, handling both dynamic and static expressions. + + :param iterable: An iterable to check if all elements evaluate to True + :type iterable: Iterable + :return: boolean of Python value or IR value + :rtype: bool or cutlass.Boolean + """ + if executor._all_executor and executor._is_dynamic_expression(iterable): + return executor._all_executor(iterable) + else: + return all(iterable) + + +# ============================================================================= +# Control flow checks +# ============================================================================= +def range_value_check(*args): + """ + Ensure all `range_constexpr` bounds are compile-time constants (Python ints). + """ + try: + return tuple(arg.__index__() for arg in args) + except: + raise DSLRuntimeError( + "`range_constexpr` requires constexpr (compile-time constant) for all arguments.", + suggestion="Use `range` instead of `range_constexpr`.", + ) + + +def range_perf_warning(filename, lineno, *args): + has_dynamic_expr = False + for arg in args: + if executor._is_dynamic_expression(arg): + has_dynamic_expr = True + break + if not has_dynamic_expr: + warnings.warn_explicit( + ( + "The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression." + "If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`." + ), + category=UserWarning, + filename=filename, + lineno=lineno, + ) diff --git a/python/CuTeDSL/base_dsl/ast_preprocessor.py b/python/CuTeDSL/base_dsl/ast_preprocessor.py index f1e1c635..ea73831c 100644 --- a/python/CuTeDSL/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/base_dsl/ast_preprocessor.py @@ -16,7 +16,7 @@ It uses Python's AST and rewrites specific Python statements such as `for` and ` The preprocessor operates on the following constructs: - `for` loops: - Rewrites `for` loops with the `@loop_selector` decorator. - - Supports `range`, `range_dynamic`, and `range_constexpr` for loop iteration. + - Supports `range`, `range_dynamic` for loop iteration. - `if-elif-else` statements: - Rewrites conditional statements with the `@if_selector` decorator. - Supports `dynamic_expr` and `const_expr` in the condition expressions. @@ -36,9 +36,11 @@ import ast import importlib import inspect import textwrap +import warnings from dataclasses import dataclass from typing import List, Set, Dict, Any, Callable, Optional from types import ModuleType +from collections import OrderedDict from .common import * from .utils.logger import log @@ -69,6 +71,22 @@ class OrderedSet: def __sub__(self, other): return OrderedSet(key for key in self._dict if key not in other) + def intersections(self, others): + """Compute the intersection of this set with multiple other sets. + + :param others: A list of sets to compute intersections with + :type others: List[Set[str]] + :return: A new ordered set containing elements that appear in this set + and at least one of the other sets + """ + result = OrderedSet() + for key in self._dict: + for other in reversed(others): + if key in other: + result.add(key) + break + return result + @dataclass class ScopeManager: @@ -78,23 +96,23 @@ class ScopeManager: """ scopes: List[Set[str]] - current_scope: Set[str] @classmethod def create(cls) -> "ScopeManager": - return cls([], set()) - - def enter_scope(self) -> None: - self.scopes.append(self.current_scope.copy()) - - def exit_scope(self) -> None: - self.current_scope = self.scopes.pop() + return cls([]) def add_to_scope(self, name: str) -> None: - self.current_scope.add(name) + self.scopes[-1].add(name) - def get_active_symbols(self) -> Set[str]: - return set(self.current_scope) + def get_active_symbols(self) -> List[Set[str]]: + return self.scopes.copy() + + def __enter__(self) -> "ScopeManager": + self.scopes.append(set()) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.scopes.pop() class DSLPreprocessor(ast.NodeTransformer): @@ -115,6 +133,9 @@ class DSLPreprocessor(ast.NodeTransformer): BOOL_CAST = "bool_cast" IMPLICIT_DOWNCAST_NUMERIC_TYPE = "implicitDowncastNumericType" SUPPORTED_FOR_RANGE_STATEMENTS = {"range", "range_dynamic", "range_constexpr"} + COMPARE_EXECUTOR = "compare_executor" + ANY_EXECUTOR = "any_executor" + ALL_EXECUTOR = "all_executor" def __init__(self): super().__init__() @@ -131,57 +152,59 @@ class DSLPreprocessor(ast.NodeTransformer): def _get_module_imports(self, decorated_func): """Extract imports from the module containing the decorated function""" + imports = OrderedDict() # Get the module containing the decorated function - module = inspect.getmodule(decorated_func) - if module is None: - return {} + if module := inspect.getmodule(decorated_func): + try: + # Get the module source code + source = inspect.getsource(module) + module_ast = ast.parse(source) - # Get the module source code - try: - source = inspect.getsource(module) - module_ast = ast.parse(source) + # Extract imports from the full module + alias = lambda n: n.asname if n.asname else n.name + for node in ast.walk(module_ast): + if isinstance(node, ast.Import): + for name in node.names: + imports[(name.name, None)] = alias(name) + elif isinstance(node, ast.ImportFrom): + module_name = node.module + if node.level > 0: + # Handle relative imports + package_name = module.__package__.rsplit( + ".", node.level - 1 + )[0] + module_name = f"{package_name}.{module_name}" + for name in node.names: + imports[(module_name, name.name)] = alias(name) + except (IOError, TypeError): + pass - # Extract imports from the full module - imports = {} - for node in ast.walk(module_ast): - if isinstance(node, ast.Import): - for name in node.names: - imports[name.name] = name.asname if name.asname else name.name - elif isinstance(node, ast.ImportFrom): - module_name = node.module - for name in node.names: - if name.name == "*": - # Handle wildcard imports - try: - imported_module = importlib.import_module(module_name) - imports[module_name] = imported_module - except ImportError: - pass - else: - full_name = f"{module_name}.{name.name}" - imports[full_name] = ( - name.asname if name.asname else name.name - ) - return imports - except (IOError, TypeError): - return {} + return imports def exec(self, function_name, original_function, code_object, exec_globals): # Get imports from the original module module_imports = self._get_module_imports(original_function) # Import all required modules - for module_path, alias in module_imports.items(): + for (module_path, attr_name), alias_name in module_imports.items(): try: - if "." in module_path: - base_module, attribute = module_path.rsplit(".", 1) - module = importlib.import_module(base_module) - if hasattr(module, attribute): - attr = getattr(module, attribute) - exec_globals[alias] = attr + module = importlib.import_module(module_path) + if attr_name: + if attr_name == "*": + if hasattr(module, "__all__"): + attrs = module.__all__ + else: + attrs = [ + name for name in dir(module) if not name.startswith("_") + ] + else: + attrs = [attr_name] + + for attr in attrs: + alias = attr if attr_name == "*" else alias_name + exec_globals[alias] = getattr(module, attr) else: - path = importlib.import_module(module_path) - exec_globals[alias] = path + exec_globals[alias_name] = module except (ImportError, AttributeError) as e: raise ImportError(f"Failed to import {module_path}: {str(e)}") @@ -306,7 +329,7 @@ class DSLPreprocessor(ast.NodeTransformer): snippet=ast.unparse(tree), suggestion=( "If predicates are constant expression, write like " - "`if const_expr(...)` or `for ... in range_constexpr`. " + "`if const_expr(...)` or `for ... in range_constexpr(...)`. " "In that case, early exit will be executed by Python " "interpreter, so it's supported." ), @@ -315,7 +338,7 @@ class DSLPreprocessor(ast.NodeTransformer): def is_node_constexpr(self, node) -> bool: """ Determines if the node is a constexpr. - Supported nodes are if, for, while statements. + Supported nodes are if, while statements. """ if isinstance(node, ast.If) or isinstance(node, ast.While): if isinstance(node.test, ast.Call): @@ -326,16 +349,26 @@ class DSLPreprocessor(ast.NodeTransformer): elif isinstance(func, ast.Name) and func.id == "const_expr": return True - elif isinstance(node, ast.For): - if isinstance(node.iter, ast.Call): - func = node.iter.func - if isinstance(func, ast.Attribute) and func.attr == "range_constexpr": - return True - - elif isinstance(func, ast.Name) and func.id == "range_constexpr": - return True return False + def _get_range_kind(self, iter_node): + """ + Return "range", "range_dynamic", "range_constexpr" or None for the iterable + """ + if isinstance(iter_node, ast.Call): + func = iter_node.func + if ( + isinstance(func, ast.Name) + and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS + ): + return func.id, True + if ( + isinstance(func, ast.Attribute) + and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS + ): + return func.attr, False + return None, None + def transform(self, original_function, exec_globals): """ Transforms the provided function using the preprocessor. @@ -350,7 +383,9 @@ class DSLPreprocessor(ast.NodeTransformer): return unified_tree - def analyze_region_variables(self, node: Union[ast.For, ast.If], active_symbols): + def analyze_region_variables( + self, node: Union[ast.For, ast.If], active_symbols: List[Set[str]] + ): """ Analyze variables in different code regions to identify read-only, write-only, and active variables for DSL constructs. @@ -426,8 +461,8 @@ class DSLPreprocessor(ast.NodeTransformer): # Argument can be Load and Store. We should just mark it as Store. read_args = read_args - write_args - used_args = read_args & active_symbols - iter_args = write_args & active_symbols + used_args = read_args.intersections(active_symbols) + iter_args = write_args.intersections(active_symbols) flattend_args = used_args | iter_args return list(used_args), list(iter_args), list(flattend_args) @@ -435,11 +470,21 @@ class DSLPreprocessor(ast.NodeTransformer): def extract_range_args(self, iter_node): args = iter_node.args if len(args) == 1: - return ast.Constant(value=0), self.visit(args[0]), ast.Constant(value=1) + return ( + self.visit(ast.Constant(value=0)), + self.visit(args[0]), + self.visit(ast.Constant(value=1)), + False, + ) elif len(args) == 2: - return self.visit(args[0]), self.visit(args[1]), ast.Constant(value=1) + return ( + self.visit(args[0]), + self.visit(args[1]), + self.visit(ast.Constant(value=1)), + False, + ) elif len(args) == 3: - return self.visit(args[0]), self.visit(args[1]), self.visit(args[2]) + return self.visit(args[0]), self.visit(args[1]), self.visit(args[2]), True else: raise DSLAstPreprocessorError( "Unsupported number of arguments in range", filename=self.file_name @@ -452,6 +497,10 @@ class DSLPreprocessor(ast.NodeTransformer): keywords.get("unroll_full", ast.Constant(value=False)), ) + def extract_pipelining_args(self, iter_node): + keywords = {kw.arg: kw.value for kw in iter_node.keywords} + return keywords.get("pipelining", ast.Constant(value=None)) + def create_loop_function( self, func_name, @@ -461,10 +510,10 @@ class DSLPreprocessor(ast.NodeTransformer): step, unroll, unroll_full, + pipelining, used_args, iter_args, flattened_args, - is_loop_constexpr, ): """ Creates a loop body function with the `loop_selector` decorator. @@ -501,10 +550,10 @@ class DSLPreprocessor(ast.NodeTransformer): func=ast.Name(id=self.DECORATOR_FOR_STATEMENT, ctx=ast.Load()), args=[start, stop, step], keywords=[ - ast.keyword(arg="unroll", value=unroll), - ast.keyword(arg="unroll_full", value=unroll_full), - ast.keyword(arg="constexpr", value=is_loop_constexpr), - ast.keyword( + ast.keyword(arg="unroll", value=unroll), + ast.keyword(arg="unroll_full", value=unroll_full), + ast.keyword(arg="pipelining", value=pipelining), + ast.keyword( arg="used_args", value=ast.List( elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args], @@ -568,80 +617,6 @@ class DSLPreprocessor(ast.NodeTransformer): value=ast.Name(id=func_name, ctx=ast.Load()), ) - def is_supported_range_call(self, node): - return ( - isinstance(node, ast.For) - and isinstance(node.iter, ast.Call) - and ( - ( - isinstance(node.iter.func, ast.Name) - and node.iter.func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS - ) - or ( - isinstance(node.iter.func, ast.Attribute) - and node.iter.func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS - ) - ) - ) - - def get_loop_constexpr(self, node): - if not self.is_supported_range_call(node): - return None - - # Map function names to their constexpr values - constexpr_map = {"range": None, "range_dynamic": False, "range_constexpr": True} - range_name = ( - node.iter.func.id - if isinstance(node.iter.func, ast.Name) - else node.iter.func.attr - ) - return ast.Constant(value=constexpr_map[range_name]) - - def transform_for_loop(self, node, active_symbols): - # Constexpr doesn't get preprocessed - if self.is_node_constexpr(node): - self.generic_visit(node) - return node - - # We only support range, range_constexpr, range_dynamic - if self.is_supported_range_call(node): - constexpr_val = self.get_loop_constexpr(node) - # Check for early exit and raise exception - self.check_early_exit(node, "for") - start, stop, step = self.extract_range_args(node.iter) - unroll, unroll_full = self.extract_unroll_args(node.iter) - used_args, iter_args, flat_args = self.analyze_region_variables( - node, active_symbols - ) - - func_name = f"loop_body_{self.counter}" - self.counter += 1 - - func_def = self.create_loop_function( - func_name, - node, - start, - stop, - step, - unroll, - unroll_full, - used_args, - iter_args, - flat_args, - constexpr_val, - ) - - assign = ast.copy_location( - self.create_loop_call(func_name, iter_args), node - ) - - # This should work fine as it modifies the AST structure - return [func_def, assign] - - self.generic_visit(node) - - return node - def visit_BoolOp(self, node): # Visit child nodes first self.generic_visit(node) @@ -725,16 +700,299 @@ class DSLPreprocessor(ast.NodeTransformer): return node + @staticmethod + def _insert_range_value_check(node): + """ + Insert a check for range arguments + """ + range_inputs = node.iter.args + check_call = ast.copy_location( + ast.Call( + func=ast.Name(id="range_value_check", ctx=ast.Load()), + args=range_inputs, + keywords=[], + ), + node.iter, + ) + node.iter = ast.copy_location( + ast.Call( + func=ast.Name(id="range", ctx=ast.Load()), + args=[ast.Starred(value=check_call, ctx=ast.Load())], + keywords=[], + ), + node.iter, + ) + def visit_For(self, node): active_symbols = self.scope_manager.get_active_symbols() - self.scope_manager.enter_scope() + with self.scope_manager: + if isinstance(node.target, ast.Name): + self.scope_manager.add_to_scope(node.target.id) + + # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop. + range_kind, is_builtin_range = self._get_range_kind(node.iter) + if range_kind == "range_constexpr" or range_kind == None: + self.generic_visit(node) + if range_kind == "range_constexpr": + # Rewrite range_constexpr to range + node.iter.func.id = "range" + self._insert_range_value_check(node) + return node + + if range_kind == "range_dynamic": + # Generate a warning + warnings.simplefilter("always", DeprecationWarning) # turn off filter + warnings.warn_explicit( + "range_dynamic is deprecated and will be removed in the future, please remove it.", + category=DeprecationWarning, + filename=self.file_name, + lineno=node.iter.lineno, + ) + warnings.simplefilter("default", DeprecationWarning) # reset filter + + warning_call = None + if range_kind == "range" and is_builtin_range: + # Warn about possible performance regression due to behavior change + warning_call = ast.Expr( + ast.Call( + func=ast.Name(id="range_perf_warning", ctx=ast.Load()), + args=[ + ast.Constant(value=self.file_name), + ast.Constant(value=node.iter.lineno), + ] + + node.iter.args, + keywords=[], + ) + ) + ast.copy_location(warning_call, node.iter) + + new_for_node = self.transform_for_loop(node, active_symbols) + + return new_for_node if warning_call is None else [warning_call] + new_for_node + + @staticmethod + def _hoist_expr_to_assignments(expr, name): + return ast.copy_location( + ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr + ) + + def _build_select_and_assign(self, *, name, test, body, orelse, location): + node = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=name, ctx=ast.Store())], + value=ast.IfExp( + test=test, + body=body, + orelse=orelse, + ), + ), + location, + ) + self.generic_visit(node) + return node + + def _handle_negative_step(self, node, start_expr, stop_expr, step_expr): + # hoist start, stop, step to assignments + start_ori_name = f"start_ori_{self.counter}" + start = self._hoist_expr_to_assignments(start_expr, start_ori_name) + stop_ori_name = f"stop_ori_{self.counter}" + stop = self._hoist_expr_to_assignments(stop_expr, stop_ori_name) + step_ori_name = f"step_ori_{self.counter}" + step = self._hoist_expr_to_assignments(step_expr, step_ori_name) + + extra_exprs = [start, stop, step] + + # Handle possible negative step, generates the following code in Python: + # isNegative = step < 0 + isNegative_name = f"isNegative_{self.counter}" + isNegative = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=isNegative_name, ctx=ast.Store())], + value=ast.Compare( + left=ast.Name(id=step_ori_name, ctx=ast.Load()), + ops=[ast.Lt()], + comparators=[ast.Constant(value=0)], + ), + ), + step, + ) + + # start = stop if isNegative else start + start_name = f"start_{self.counter}" + start = self._build_select_and_assign( + name=start_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.Name(id=stop_ori_name, ctx=ast.Load()), + orelse=ast.Name(id=start_ori_name, ctx=ast.Load()), + location=start, + ) + + # stop = start if isNegative else stop + stop_name = f"stop_{self.counter}" + stop = self._build_select_and_assign( + name=stop_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.Name(id=start_ori_name, ctx=ast.Load()), + orelse=ast.Name(id=stop_ori_name, ctx=ast.Load()), + location=stop, + ) + + # step = -step if isNegative else step + step_name = f"step_{self.counter}" + step = self._build_select_and_assign( + name=step_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.UnaryOp( + op=ast.USub(), operand=ast.Name(id=step_ori_name, ctx=ast.Load()) + ), + orelse=ast.Name(id=step_ori_name, ctx=ast.Load()), + location=step, + ) + + # offset = start + stop if isNegative else 0 + offset_name = f"offset_{self.counter}" + offset = self._build_select_and_assign( + name=offset_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.BinOp( + op=ast.Add(), + left=ast.Name(id=start_name, ctx=ast.Load()), + right=ast.Name(id=stop_name, ctx=ast.Load()), + ), + orelse=ast.Constant(value=0), + location=node, + ) + + extra_exprs.append(isNegative) + extra_exprs.append(start) + extra_exprs.append(stop) + extra_exprs.append(step) + extra_exprs.append(offset) + + # Add this to begining of loop body + # for i in range(start, stop, step): + # i = offset - i if isNegative else i + assert isinstance(node.target, ast.Name) + + target_name = node.target.id + target = self._build_select_and_assign( + name=target_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.BinOp( + op=ast.Sub(), + left=ast.Name(id=offset_name, ctx=ast.Load()), + right=ast.Name(id=target_name, ctx=ast.Load()), + ), + orelse=ast.Name(id=target_name, ctx=ast.Load()), + location=node.target, + ) + + node.body.insert(0, target) + + return ( + ast.Name(id=start_name, ctx=ast.Load()), + ast.Name(id=stop_name, ctx=ast.Load()), + ast.Name(id=step_name, ctx=ast.Load()), + extra_exprs, + ) + + def transform_for_loop(self, node, active_symbols): + # Check for early exit and raise exception + self.check_early_exit(node, "for") + if node.orelse: + raise DSLAstPreprocessorError( + "dynamic for loop with else is not supported", + filename=self.file_name, + snippet=ast.unparse(node), + ) + + # Get loop target variable name + target_var_name = None + target_var_is_active_before_loop = False if isinstance(node.target, ast.Name): - self.scope_manager.add_to_scope(node.target.id) + target_var_name = node.target.id + for active_symbol in active_symbols: + if target_var_name in active_symbol: + target_var_is_active_before_loop = True + active_symbols.remove(active_symbol) + break - new_for_node = self.transform_for_loop(node, active_symbols) - self.scope_manager.exit_scope() - return new_for_node + # Add necessary exprs to handle this + if target_var_is_active_before_loop: + # Initialize an extra loop carried variable + loop_carried_var_name = f"loop_carried_var_{self.counter}" + pre_loop_expr = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())], + value=ast.Name(id=target_var_name, ctx=ast.Load()), + ), + node, + ) + # append an extra assignment to the loop carried variable + node.body.append( + ast.copy_location( + ast.Assign( + targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())], + value=ast.Name(id=target_var_name, ctx=ast.Load()), + ), + node, + ) + ) + active_symbols.append({loop_carried_var_name}) + + start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter) + unroll, unroll_full = self.extract_unroll_args(node.iter) + pipelining = self.extract_pipelining_args(node.iter) + used_args, iter_args, flat_args = self.analyze_region_variables( + node, active_symbols + ) + + if has_step: + start, stop, step, exprs = self._handle_negative_step( + node, start_expr, stop_expr, step_expr + ) + else: + start, stop, step, exprs = start_expr, stop_expr, step_expr, [] + + if target_var_is_active_before_loop: + exprs.append(pre_loop_expr) + + func_name = f"loop_body_{self.counter}" + self.counter += 1 + + func_def = self.create_loop_function( + func_name, + node, + start, + stop, + step, + unroll, + unroll_full, + pipelining, + used_args, + iter_args, + flat_args, + ) + + assign = ast.copy_location(self.create_loop_call(func_name, iter_args), node) + + # This should work fine as it modifies the AST structure + exprs = exprs + [func_def, assign] + + if target_var_is_active_before_loop: + # Create a new assignment to the target variable + exprs.append( + ast.copy_location( + ast.Assign( + targets=[ast.Name(id=target_var_name, ctx=ast.Store())], + value=ast.Name(id=loop_carried_var_name, ctx=ast.Load()), + ), + node, + ) + ) + + return exprs def visit_Name(self, node): self.generic_visit(node) @@ -765,16 +1023,30 @@ class DSLPreprocessor(ast.NodeTransformer): func = node.func self.generic_visit(node) - # Check if the function is 'bool' - if isinstance(func, ast.Name) and func.id == "bool": - return ast.copy_location( - ast.Call( - func=ast.Name(id=self.BOOL_CAST, ctx=ast.Load()), - args=[node.args[0]], - keywords=[], - ), - node, - ) + # Rewrite call to some built-in functions + if isinstance(func, ast.Name): + # Check if the function is 'bool' + if func.id == "bool": + return ast.copy_location( + ast.Call( + func=ast.Name(id=self.BOOL_CAST, ctx=ast.Load()), + args=[node.args[0]], + keywords=[], + ), + node, + ) + elif func.id in ["any", "all"]: + helper_func = ( + self.ANY_EXECUTOR if func.id == "any" else self.ALL_EXECUTOR + ) + return ast.copy_location( + ast.Call( + func=ast.Name(id=helper_func, ctx=ast.Load()), + args=[node.args[0]], + keywords=[], + ), + node, + ) elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): def create_downcast_call(arg): return ast.copy_location( @@ -893,21 +1165,20 @@ class DSLPreprocessor(ast.NodeTransformer): return new_decorator_list def visit_FunctionDef(self, node): - self.scope_manager.enter_scope() - self.function_counter += 1 - self.function_name = node.name - if self.function_depth > 0: - self.local_closures.add(node.name) + with self.scope_manager: + self.function_counter += 1 + self.function_name = node.name + if self.function_depth > 0: + self.local_closures.add(node.name) - self.function_depth += 1 + self.function_depth += 1 - # Add function name and arguments - self.scope_manager.add_to_scope(node.name) - for arg in node.args.args: - self.scope_manager.add_to_scope(arg.arg) + # Add function name and arguments + self.scope_manager.add_to_scope(node.name) + for arg in node.args.args: + self.scope_manager.add_to_scope(arg.arg) - self.generic_visit(node) - self.scope_manager.exit_scope() + self.generic_visit(node) self.function_depth -= 1 @@ -916,55 +1187,51 @@ class DSLPreprocessor(ast.NodeTransformer): return node def visit_With(self, node): - self.scope_manager.enter_scope() + with self.scope_manager: + for item in node.items: + if isinstance(item.optional_vars, ast.Name): + self.scope_manager.add_to_scope(item.optional_vars.id) + self.generic_visit(node) - for item in node.items: - if isinstance(item.optional_vars, ast.Name): - self.scope_manager.add_to_scope(item.optional_vars.id) - self.generic_visit(node) - - self.scope_manager.exit_scope() return node def visit_While(self, node): active_symbols = self.scope_manager.get_active_symbols() - self.scope_manager.enter_scope() + # print(active_symbols) + with self.scope_manager: + # Constexpr doesn't get preprocessed + if self.is_node_constexpr(node): + self.generic_visit(node) + return node - # Constexpr doesn't get preprocessed - if self.is_node_constexpr(node): - self.generic_visit(node) - self.scope_manager.exit_scope() - return node + # Check for early exit and raise exception + self.check_early_exit(node, "while") - # Check for early exit and raise exception - self.check_early_exit(node, "while") + used_args, yield_args, flat_args = self.analyze_region_variables( + node, active_symbols + ) + func_name = f"while_region_{self.counter}" + self.counter += 1 - used_args, yield_args, flat_args = self.analyze_region_variables( - node, active_symbols - ) - func_name = f"while_region_{self.counter}" - self.counter += 1 + func_def = self.create_while_function( + func_name, node, used_args, yield_args, flat_args + ) + assign = ast.copy_location( + self.create_loop_call(func_name, yield_args), node + ) - func_def = self.create_while_function( - func_name, node, used_args, yield_args, flat_args - ) - assign = ast.copy_location(self.create_loop_call(func_name, yield_args), node) - - self.scope_manager.exit_scope() return [func_def, assign] def visit_Try(self, node): - self.scope_manager.enter_scope() - self.generic_visit(node) - self.scope_manager.exit_scope() + with self.scope_manager: + self.generic_visit(node) return node def visit_ExceptHandler(self, node): - self.scope_manager.enter_scope() - if node.name: # Exception variable - self.scope_manager.add_to_scope(node.name) - self.generic_visit(node) - self.scope_manager.exit_scope() + with self.scope_manager: + if node.name: # Exception variable + self.scope_manager.add_to_scope(node.name) + self.generic_visit(node) return node def create_if_call(self, func_name, yield_args, flat_args): @@ -992,17 +1259,7 @@ class DSLPreprocessor(ast.NodeTransformer): Visits an inline if-else expression (ternary operator). This is the Python equivalent of `x if condition else y`. """ - # Check if the condition is constexpr - constexpr_val, test = self.is_constexpr(node) - - node.test = test - node.body = self.visit(node.body) - node.orelse = self.visit(node.orelse) - - # If it's a constexpr node, we don't need to transform it - if constexpr_val.value is True: - return node - + self.generic_visit(node) # Emit # node if type(pred) == bool else select_(pred, body, orelse) # so if pred is a python bool, use python to short-circuit and avoid emit arith.select @@ -1031,61 +1288,78 @@ class DSLPreprocessor(ast.NodeTransformer): node, ) + cmpops = { + "Eq": "==", + "NotEq": "!=", + "Lt": "<", + "LtE": "<=", + "Gt": ">", + "GtE": ">=", + "Is": "is", + "IsNot": "is not", + "In": "in", + "NotIn": "not in", + } + def compare_ops_to_str(self, node): + names = [ + ast.Constant(value=self.cmpops[op.__class__.__name__]) for op in node.ops + ] + return ast.List(elts=names, ctx=ast.Load()) + + def visit_Compare(self, node): + self.generic_visit(node) + + comparator_strs = self.compare_ops_to_str(node) + + keywords = [ + ast.keyword(arg="left", value=node.left), + ast.keyword( + arg="comparators", value=ast.List(elts=node.comparators, ctx=ast.Load()) + ), + ast.keyword(arg="ops", value=comparator_strs), + ] + + call = ast.copy_location( + ast.Call( + func=ast.Name(id=self.COMPARE_EXECUTOR, ctx=ast.Load()), + args=[], + keywords=keywords, + ), + node, + ) + + return call + def visit_If(self, node): active_symbols = self.scope_manager.get_active_symbols() - self.scope_manager.enter_scope() + with self.scope_manager: + # non-dynamic expr doesn't get preprocessed + if self.is_node_constexpr(node): + self.generic_visit(node) + return node - # Constexpr doesn't get preprocessed - if self.is_node_constexpr(node): - self.generic_visit(node) - self.scope_manager.exit_scope() - return node + # Check for early exit and raise exception + self.check_early_exit(node, "if") - # Check for early exit and raise exception - self.check_early_exit(node, "if") + used_args, yield_args, flat_args = self.analyze_region_variables( + node, active_symbols + ) + func_name = f"if_region_{self.counter}" + self.counter += 1 - used_args, yield_args, flat_args = self.analyze_region_variables( - node, active_symbols - ) - func_name = f"if_region_{self.counter}" - self.counter += 1 + func_def = self.create_if_function( + func_name, node, used_args, yield_args, flat_args + ) + assign = ast.copy_location( + self.create_if_call(func_name, yield_args, flat_args), node + ) - func_def = self.create_if_function( - func_name, node, used_args, yield_args, flat_args - ) - assign = ast.copy_location( - self.create_if_call(func_name, yield_args, flat_args), node - ) - - self.scope_manager.exit_scope() return [func_def, assign] - def is_constexpr(self, node): - """Determines if the if condition is wrapped in const_expr or dynamic_expr""" - if isinstance(node.test, ast.Call): - func = node.test.func - - # Check if the function is 'const_expr' - if isinstance(func, ast.Name) and func.id == "const_expr": - return ast.Constant(value=True), node.test.args[0] - - # Check if the function is 'dynamic_expr' - elif isinstance(func, ast.Name) and func.id == "dynamic_expr": - return ast.Constant(value=False), self.visit(node.test.args[0]) - - # Check if it's an attribute access for 'const_expr' or 'dynamic_expr' - elif isinstance(func, ast.Attribute): - if func.attr == "const_expr": - return ast.Constant(value=True), node.test.args[0] - elif func.attr == "dynamic_expr": - return ast.Constant(value=False), self.visit(node.test.args[0]) - - return ast.Constant(value=None), self.visit(node.test) - def create_if_function( self, func_name, node, used_args, yield_args, flattened_args ): - is_constexpr, test_expr = self.is_constexpr(node) + test_expr = self.visit(node.test) pred_name = self.make_func_param_name("pred", flattened_args) func_args = [ast.arg(arg=pred_name, annotation=None)] func_args += [ast.arg(arg=var, annotation=None) for var in flattened_args] @@ -1265,8 +1539,6 @@ class DSLPreprocessor(ast.NodeTransformer): arg="else_block", value=ast.Name(id=else_block_name, ctx=ast.Load()) ) ) - # Add constexpr - execute_keywords.append(ast.keyword(arg="constexpr", value=is_constexpr)) execute_call = ast.copy_location( ast.Call( @@ -1328,7 +1600,7 @@ class DSLPreprocessor(ast.NodeTransformer): cond, yield_args = while_before_block(yield_args) return yield_args """ - is_constexpr, test_expr = self.is_constexpr(node) + test_expr = self.visit(node.test) pred_name = self.make_func_param_name("pred", flattened_args) # Section: decorator construction @@ -1438,7 +1710,6 @@ class DSLPreprocessor(ast.NodeTransformer): arg="while_after_block", value=ast.Name(id=while_after_block_name, ctx=ast.Load()), ), - ast.keyword(arg="constexpr", value=is_constexpr), ast.keyword( arg="yield_arg_names", value=ast.List( diff --git a/python/CuTeDSL/base_dsl/dsl.py b/python/CuTeDSL/base_dsl/dsl.py index 4870cdae..db3f0392 100644 --- a/python/CuTeDSL/base_dsl/dsl.py +++ b/python/CuTeDSL/base_dsl/dsl.py @@ -164,16 +164,17 @@ def _mlir_type_to_numpy_type(type): def is_dynamic_expression(value): """ - Check if the value is an MLIR's SSA value. + Given the `value`, check if itself is an IR value or recursively go through it to check if it contains IR value """ - # Case 1: If the value has MLIR's SSA value, return True - # Case 2: If the value supports __extract_mlir_values__ then it's possible to get SSA value - return ( - isinstance(value, ir.Value) - or hasattr(value, "__extract_mlir_values__") - or len(extract_mlir_values(value)) > 0 - ) - + if isinstance(value, (tuple, list)): + for x in value: + if is_dynamic_expression(x): + return True + elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr( + value, "__extract_mlir_values__" + ): + return True + return False def extract_mlir_values(obj): """ @@ -726,6 +727,7 @@ class BaseDSL: ) jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], [] + jit_adapted_args = [] default_attr = ir.DictAttr.get({}) input_args = [*args, *kwargs.values()] @@ -759,7 +761,9 @@ class BaseDSL: # If not any known type, try JIT argument adapter # to convert the argument adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) - arg = adapter(arg) if adapter else arg + if adapter: + arg = adapter(arg) + jit_adapted_args.append(arg) if is_host: jit_exec_arg.extend(get_c_pointers(arg)) @@ -798,14 +802,14 @@ class BaseDSL: jit_arg_types.extend(jit_arg_type) jit_arg_attrs.extend(jit_arg_attr) - return jit_exec_args, jit_arg_types, jit_arg_attrs + return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args def generate_mlir_function_types( self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec ): """Convert input arguments to MLIR function signature also convert numpy arrays to memref.""" - exe_args, types, _ = self._generate_jit_func_args( + exe_args, types, attrs, adapted_args = self._generate_jit_func_args( func, function_name, input_args, kwargs, args_spec, is_host=True ) @@ -816,7 +820,7 @@ class BaseDSL: types ), "expects the same number of arguments and function parameters" - return exe_args, types + return exe_args, types, adapted_args @dataclass class LaunchConfig: @@ -1158,7 +1162,7 @@ class BaseDSL: """Generate MLIR module and compile iself.T_provider.""" with ir.Context(), ir.Location.unknown(): # Convert input arguments to MLIR arguments - exe_args, func_types = self.generate_mlir_function_types( + exe_args, func_types, adapted_args = self.generate_mlir_function_types( funcBody, function_name, args, kwargs, args_spec ) @@ -1476,7 +1480,7 @@ class BaseDSL: if self.device_compilation_only: return kernel_operands, kernel_arg_types, kernel_arg_attrs - kernel_operands, kernel_arg_types, kernel_arg_attrs = ( + kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = ( self._generate_jit_func_args( kernel_func, kernel_name, args, kwargs, args_spec, is_host=False ) @@ -1586,12 +1590,14 @@ class BaseDSL: if self.device_compilation_only: log().debug("Generating cuda-python arguments") # Convert input arguments to MLIR arguments - self.exe_args, kernel_types = self.generate_mlir_function_types( - funcBody, - kernel_name, - canonicalized_args, - canonicalized_kwargs, - args_spec, + self.exe_args, kernel_types, _ = ( + self.generate_mlir_function_types( + funcBody, + kernel_name, + canonicalized_args, + canonicalized_kwargs, + args_spec, + ) ) helper = kernelGenHelper() diff --git a/python/CuTeDSL/base_dsl/env_manager.py b/python/CuTeDSL/base_dsl/env_manager.py index ef1fea7a..4a8f6591 100644 --- a/python/CuTeDSL/base_dsl/env_manager.py +++ b/python/CuTeDSL/base_dsl/env_manager.py @@ -78,11 +78,9 @@ def detect_gpu_arch(prefix): major, minor = arch suffix = "" - if major >= 9 and minor >= 0: + if major >= 9: suffix = "a" - elif minor != 0: - # e.g sm_86, belong with sm_80 family - minor = 0 + return f"sm_{major}{minor}{suffix}" diff --git a/python/CuTeDSL/base_dsl/jit_executor.py b/python/CuTeDSL/base_dsl/jit_executor.py index 3ed9282b..83268009 100644 --- a/python/CuTeDSL/base_dsl/jit_executor.py +++ b/python/CuTeDSL/base_dsl/jit_executor.py @@ -12,23 +12,24 @@ """ This module provides jit executor related classes """ -import io -import inspect import ctypes -import numpy as np +import inspect +import io from typing import get_origin +import numpy as np + +# MLIR modules imports +from .._mlir import ir + # Local modules imports -from .utils.timer import timer -from .utils.logger import log +from . import typing as t from .common import DSLRuntimeError from .runtime import cuda as cuda_helpers from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr from .typing import get_c_pointers -from . import typing as t - -# MLIR modules imports -from .._mlir import ir +from .utils.logger import log +from .utils.timer import timer class CudaSingleModule: @@ -64,6 +65,7 @@ class JitExecutor: self.args_spec = args_spec self.function_name = function_name if args_spec is not None: + self.original_args_spec = args_spec self.args_spec = self.filter_runtime_arg_spec(args_spec) # cuda kernels self.cuda_modules = cuda_modules @@ -135,6 +137,29 @@ class JitExecutor: for module in set(cuda_modules): cuda_helpers.unload_cubin_module(module) + def get_constexpr_args(self) -> list[dict[str, int | str]]: + """ + This function returns the constexpr args that have been pruned from the original function signature. + The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). + + :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). + :rtype: list[dict[str, int | str]] + """ + if self.original_args_spec is None: + return list() + constexpr_args = list() + for i, arg_name in enumerate(self.original_args_spec.args): + if arg_name not in self.args_spec.args: + constexpr_args.append({"argument_index": i, "argument_name": arg_name}) + + if self.original_args_spec.kwonlyargs: + for kwarg in self.original_args_spec.kwonlyargs: + if kwarg not in self.args_spec.kwonlyargs: + constexpr_args.append( + {"argument_index": None, "argument_name": kwarg} + ) + return constexpr_args + def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec): """ This function is the prune version of `generate_mlir_function_types` which only generates execution args @@ -175,6 +200,7 @@ class JitExecutor: ) exe_args = [] + adapted_args = [] input_args = rectified_args + list(rectified_kwargs.values()) input_arg_names = args_spec.args + args_spec.kwonlyargs for arg, arg_name in zip(input_args, input_arg_names): @@ -193,13 +219,16 @@ class JitExecutor: adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) if adapter: arg = adapter(arg) + adapted_args.append(arg) exe_args.extend(get_c_pointers(arg)) - return exe_args + return exe_args, adapted_args def __call__(self, *args, **kwargs): - exe_args = self.generate_execution_args(args, kwargs, self.args_spec) + exe_args, adapted_args = self.generate_execution_args( + args, kwargs, self.args_spec + ) self.run_compiled_program(exe_args) diff --git a/python/CuTeDSL/base_dsl/typing.py b/python/CuTeDSL/base_dsl/typing.py index 6554db61..3724d325 100644 --- a/python/CuTeDSL/base_dsl/typing.py +++ b/python/CuTeDSL/base_dsl/typing.py @@ -46,29 +46,75 @@ from .._mlir.dialects import arith, math @runtime_checkable class DynamicExpression(Protocol): - """ - This is a protocol class that provides a common interface - to generate user-defined dynamic expressions. + """Protocol defining the interface for object holding dynamic values in the DSL. - The DSL checks this protocol to determine if a class is a dynamic expression (SSA value) or not. + This protocol enables classes to represent dynamic values in the DSL. Classes implementing + this protocol can be used in JIT-compiled functions and dynamic value generation. + + It is required for custom data types to work correctly with following JIT features: + * as function argument to call another JIT function from JIT function + * as return value from JIT function + * for constructions like if-else, while-loop, etc. + + :param value: The MLIR operation result value to initialize the object with + :type value: ir.Value + + **Required Methods** + + * ``__extract_mlir_values__``: Extract MLIR values from the object + * ``__new_from_mlir_values__``: Create new instance from MLIR values + + **Implementation Example** + + To implement a custom data type that works with the DSL: + + .. code-block:: python + + class CustomData(metaclass=DslType): + def __init__(self, int_value): + self.int_value = int_value + + def __extract_mlir_values__(self): + return [self.int_value] + + def __new_from_mlir_values__(self, values): + return CustomData(values[0]) + + **Usage in JIT Functions** + + When used in JIT-compiled functions, the DSL automatically extracts MLIR values: + + .. code-block:: python + + @jit + def caller(): + x = CustomData(1) + return foo(x) + + This generates MLIR like: + + .. code-block:: mlir + + func @caller() -> i32 { + %0 = func.call @foo(%arg0) : (i32) -> i32 + return %0 : i32 + } """ def __extract_mlir_values__(self): - """ - Generate a dynamic expression for the current object. + """Extract MLIR values from this object. - :return: List of MLIR values + :return: List of MLIR values representing this object's data :rtype: List[ir.Value] """ raise NotImplementedError def __new_from_mlir_values__(self, values): - """ - Create a new object from MLIR values. + """Create a new instance from MLIR values. - :param values: List of MLIR values + :param values: List of MLIR values to construct the object from :type values: List[ir.Value] - :return: A new instance of the class that implements this protocol + :return: New instance of the implementing class :rtype: Any """ raise NotImplementedError @@ -77,50 +123,73 @@ class DynamicExpression(Protocol): @runtime_checkable class JitArgument(Protocol): """ - This is a protocol class that provides a common interface - for JIT function arguments generation for Python to call JIT functions. + Protocol class defining the interface for JIT function argument generation. - The DSL checks this protocol to determine if a class is capable of providing information - needed for generating JIT function arguments. + This protocol enables classes to provide the necessary information for generating + JIT function arguments and allow the DSL JIT executor to call JIT compiled functions. - See breakdowns below for JitArgument protocol based JIT function calls. + **Required Methods** + + * ``__c_pointers__``: Returns ctypes pointers for runtime execution + * ``__get_mlir_types__``: Returns MLIR types for function definition + * ``__new_from_mlir_values__``: Creates new instances from MLIR values + + **Example** .. code-block:: python + class CustomData: + def __init__(self, int_value, ...): + self.int_value = int_value + ... + + def __c_pointers__(self): + return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...] + + def __get_mlir_types__(self): + return [ir.IntegerType.get(32), ...] + + def __new_from_mlir_values__(self, values): + return CustomData(values[0], ...) + @jit def foo(x: CustomData): - return x.int_value + 1 + a = x.int_value + 1 + ... - # Emit: `%c0 = arith.constant(1, i32)` - c1 = const(1, Int32) - # `c1` tracks `%c0` defined outside of function body of `foo` - # `%c0` can't be used directly in function body of `foo` - x = CustomData(c1, ...) + # `CustomData` is an argument of `foo` + foo(CustomData(1, ...)) When called like ``y = foo(x)``, the following steps occur: - 1. JIT compiler generates MLIR function definition using ``__get_mlir_types__``: + 1. JIT compiler generates MLIR function definition using ``__get_mlir_types__`` .. code-block:: mlir - func @foo(%arg0: i32, ...) -> i32 { + func.func @foo(%arg0: i32, ...) { ... + + return } - 2. Function is traced in Python, wrapping MLIR values with ``__new_from_mlir_values__``: + 2. JIT function can't use values from Python, so it needs to reconstruct the object from + MLIR values, a.k.a `%arg0`, with ``__new_from_mlir_values__`` and pass it to `foo`. + + Following code demonstrates how JIT compiler reconstructs the object and pass to Python. .. code-block:: python # Implementation of IR tracing new_x = CustomData(ir.Value(%arg0), ...) y = foo(new_x) - # `x.int_value` is %arg0 rather than `c1` defined outside + # `x.int_value` is %arg0 rather than `c1` defined by Python. - 3. For Python runtime execution, JIT engine invokes compiled function using ``__c_pointers__``: + 3. For Python runtime execution, JIT engine invokes compiled function using ``__c_pointers__`` + pointing to the underlying data object passing to JIT compiled function. .. code-block:: python - jit_engine.invoke(foo, concat([x.__c_pointers__(), ...])) + jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...])) """ def __c_pointers__(self): @@ -224,47 +293,6 @@ class DslType(type): :property mlir_type: Returns the corresponding MLIR type for this DSL type :type mlir_type: Any - **Examples** - - Define a custom data type: - - .. code-block:: python - - class CustomData(metaclass=DslType, ...): - def __init__(self, int_value, ...): - self.int_value = int_value - ... - - def __str__(cls): - return "CustomData[int, ...]" - - def __c_pointers__(self): - return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...] - - def __get_mlir_types__(self): - return [_T.i32(), ...] - - def __extract_mlir_values__(self): - return [self.int_value, ...] - - def __new_from_mlir_values__(self, values): - return CustomData(values[0], ...) - - For JIT function calls, MLIR values are extracted with ``__extract_mlir_values__``: - - .. code-block:: python - - @jit - def caller(): - x = CustomData(1, ...) - return foo(x) - - .. code-block:: mlir - - func @caller() -> i32 { - %0 = func.call @foo(%arg0, ...) : (i32, ...) -> i32 - return %0 : i32 - } """ _is_abstract: bool @@ -946,9 +974,12 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): :return: The result of the logical not operation :rtype: Boolean """ - ty = type(self) - zero_val = arith.constant(ty.mlir_type, ty.zero) - return self.__eq__(ty(zero_val), loc=loc, ip=ip) + if isinstance(self.value, (int, float, bool)): + return not self.value + else: + ty = type(self) + zero_val = arith.constant(ty.mlir_type, ty.zero) + return self.__eq__(ty(zero_val), loc=loc, ip=ip) def __dsl_and__(self, other, *, loc=None, ip=None): """DSL implementation of Python's `and` operator. @@ -1057,6 +1088,15 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): ], ) + def __index__(self): + if isinstance(self.value, (int, float, bool)): + return self.value + else: + raise DSLRuntimeError( + f"'{type(self.value)}' object cannot be interpreted as an integer", + suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator", + ) + def __neg__(self, *, loc=None, ip=None): if isinstance(self, (bool, int, float)): return type(self)(-self.value) # type: ignore @@ -1813,7 +1853,7 @@ class IRVariadic: def __init__(self, operands): """ - Create a list of variadic operands. `operands` must be SSA values. + Create a list of variadic operands. `operands` must be dynamic values. """ self.operands = operands diff --git a/python/CuTeDSL/cutlass/cute/__init__.py b/python/CuTeDSL/cutlass/cute/__init__.py index 11496402..0cfb0e03 100644 --- a/python/CuTeDSL/cutlass/cute/__init__.py +++ b/python/CuTeDSL/cutlass/cute/__init__.py @@ -68,7 +68,9 @@ from .core import ( select, front, is_major, + leading_dim, find, + find_if, coalesce, group_modes, cosize, @@ -221,7 +223,9 @@ __all__ = [ "select", "front", "is_major", + "leading_dim", "find", + "find_if", "coalesce", "group_modes", "cosize", diff --git a/python/CuTeDSL/cutlass/cute/arch/__init__.py b/python/CuTeDSL/cutlass/cute/arch/__init__.py index 5114b97f..42806994 100644 --- a/python/CuTeDSL/cutlass/cute/arch/__init__.py +++ b/python/CuTeDSL/cutlass/cute/arch/__init__.py @@ -25,12 +25,13 @@ __all__ = [ # # mbar.py # - "mbarrier_init_arrive_cnt", + "mbarrier_init", "mbarrier_init_fence", - "mbarrier_init_tx_bytes", + "mbarrier_arrive_and_expect_tx", + "mbarrier_expect_tx", "mbarrier_wait", "mbarrier_try_wait", - "conditional_mbarrier_try_wait", + "mbarrier_conditional_try_wait", "mbarrier_arrive", # # nvvm_wrappers.py @@ -51,6 +52,7 @@ __all__ = [ "shuffle_sync_down", "shuffle_sync_bfly", "barrier", + "barrier_arrive", "sync_threads", "sync_warp", "fence_acq_rel_cta", diff --git a/python/CuTeDSL/cutlass/cute/arch/elect.py b/python/CuTeDSL/cutlass/cute/arch/elect.py index fce82b13..ead552af 100644 --- a/python/CuTeDSL/cutlass/cute/arch/elect.py +++ b/python/CuTeDSL/cutlass/cute/arch/elect.py @@ -69,7 +69,16 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion: pass """ arch = CuTeDSL._get_dsl().envar.arch - check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) is_thread_leader = nvvm.elect_sync(T.bool()) if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip) return IfOpRegion(if_op.then_block, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/arch/mbar.py b/python/CuTeDSL/cutlass/cute/arch/mbar.py index b4dc3725..8a6e3cfb 100644 --- a/python/CuTeDSL/cutlass/cute/arch/mbar.py +++ b/python/CuTeDSL/cutlass/cute/arch/mbar.py @@ -8,6 +8,7 @@ # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from typing import Optional from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op @@ -26,7 +27,7 @@ from ...impl_utils import check_value_in @dsl_user_op -def mbarrier_init_arrive_cnt(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: +def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: """ Initializes a mbarrier with the specified thread arrival count. @@ -46,16 +47,25 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None: A fence operation that applies to the mbarrier initializations. """ arch = CuTeDSL._get_dsl().envar.arch - check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) nvvm.fence_mbarrier_init(loc=loc, ip=ip) @dsl_user_op -def mbarrier_init_tx_bytes( +def mbarrier_arrive_and_expect_tx( mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None ) -> None: """ - Initializes a mbarrier with the specified number of transaction bytes. + Arrives on a mbarrier and expects a specified number of transaction bytes. :param mbar_ptr: A pointer to the mbarrier in SMEM :type mbar_ptr: Pointer @@ -66,7 +76,16 @@ def mbarrier_init_tx_bytes( SMEM. """ arch = CuTeDSL._get_dsl().envar.arch - check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) mbar_llvm_ptr = mbar_ptr.llvm_ptr if peer_cta_rank_in_cluster is not None: @@ -91,6 +110,56 @@ def mbarrier_init_tx_bytes( ) +@dsl_user_op +def mbarrier_expect_tx( + mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None +) -> None: + """ + Expects a specified number of transaction bytes without an arrive. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param bytes: The number of transaction bytes + :type bytes: Int + :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to + the mbarrier is converted to a remote address in the peer CTA's + SMEM. + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + if peer_cta_rank_in_cluster is not None: + mbar_llvm_ptr = nvvm.mapa( + mbar_llvm_ptr.type, + mbar_llvm_ptr, + Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + space = nvvm.MBarrierSpaceKind.CLUSTER + else: + space = nvvm.MBarrierSpaceKind.CTA + + nvvm.mbarrier_txn( + mbar_llvm_ptr, + Int32(bytes).ir_value(loc=loc, ip=ip), + kind=nvvm.MBarrierTxnKind.EXPECT_TX, + space=space, + loc=loc, + ip=ip, + ) + + @dsl_user_op def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: """ @@ -102,7 +171,16 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: :type phase: Int """ arch = CuTeDSL._get_dsl().envar.arch - check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) timeout_ns = 10000000 # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX @@ -129,7 +207,16 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo :rtype: Boolean """ arch = CuTeDSL._get_dsl().envar.arch - check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) return Boolean( nvvm.mbarrier_wait_parity( @@ -144,7 +231,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo @dsl_user_op -def conditional_mbarrier_try_wait( +def mbarrier_conditional_try_wait( cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None ) -> Boolean: """ @@ -159,7 +246,16 @@ def conditional_mbarrier_try_wait( :rtype: Boolean """ arch = CuTeDSL._get_dsl().envar.arch - check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) return if_generate( cond, lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip), @@ -171,7 +267,11 @@ def conditional_mbarrier_try_wait( @dsl_user_op def mbarrier_arrive( - mbar_ptr: Pointer, peer_cta_rank_in_cluster: Int = None, *, loc=None, ip=None + mbar_ptr: Pointer, + peer_cta_rank_in_cluster: Optional[Int] = None, + *, + loc=None, + ip=None, ) -> None: """ Arrives on an mbarrier. @@ -185,7 +285,16 @@ def mbarrier_arrive( mbar_llvm_ptr = mbar_ptr.llvm_ptr if peer_cta_rank_in_cluster is not None: arch = CuTeDSL._get_dsl().envar.arch - check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) mbar_llvm_ptr = nvvm.mapa_shared_cluster( mbar_llvm_ptr.type, diff --git a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py index 03d83c26..f247cf60 100644 --- a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +++ b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -225,6 +225,25 @@ def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> No barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip ) + +@dsl_user_op +def barrier_arrive( + *, barrier_id=None, number_of_threads=None, loc=None, ip=None +) -> None: + if barrier_id is not None: + barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) + + if number_of_threads is None: + raise ValueError( + "barrier_arrive needs pass number_of_threads to arrive the barrier", + ) + number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) + + nvvm.barrier_arrive( + barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip + ) + + @dsl_user_op def sync_threads(*, loc=None, ip=None) -> None: """ @@ -545,3 +564,20 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: asm_dialect=llvm.AsmDialect.AD_ATT, ) ) + + +# TODO: add `fastmath` flag for this op +@dsl_user_op +def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: + LOG2_E = 1.4426950408889634 + return exp2(a * LOG2_E, loc=loc, ip=ip) + + +# TODO: add `fastmath` flag for this op +@dsl_user_op +def exp_packed_f32x2( + a: Tuple[Float32, Float32], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + LOG2_E = Float32(1.4426950408889634) + b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip) + return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index 3f67a411..9e240634 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -11,13 +11,25 @@ import copy as py_copy from dataclasses import dataclass +import inspect import math import operator from abc import ABC, abstractmethod from functools import lru_cache, partial, reduce from inspect import isclass from itertools import chain -from typing import Iterable, overload, List, Tuple, Union, Type, Any, Dict, Optional +from typing import ( + Callable, + Iterable, + overload, + List, + Tuple, + Union, + Type, + Any, + Dict, + Optional, +) from enum import Enum, auto from cutlass.cutlass_dsl import ( @@ -100,10 +112,12 @@ def _pack_x(x, packer, op, *, loc=None, ip=None) -> ir.Value: def _pack_shape(shape: Shape, *, loc=None, ip=None) -> ir.Value: + _check_shape(shape) return _pack_x(shape, _cute_ir.pack_shape, _cute_ir.MakeShapeOp, loc=loc, ip=ip) def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: + _check_stride(stride) # Convert basis elements to the base class before _pack_x stride = transform_leaf( lambda x: x.to(_cute_ir.ScaledBasis) if isinstance(x, ScaledBasis) else x, @@ -113,16 +127,20 @@ def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: def _pack_coord(coord: Coord, *, loc=None, ip=None) -> ir.Value: + _check_coord(coord) return _pack_x(coord, _cute_ir.pack_coord, _cute_ir.MakeCoordOp, loc=loc, ip=ip) def _pack_int_tuple(int_tuple: IntTuple, *, loc=None, ip=None) -> ir.Value: + _check_int_tuple(int_tuple) return _pack_x( int_tuple, _cute_ir.pack_int_tuple, _cute_ir.MakeIntTupleOp, loc=loc, ip=ip ) def _pack_tile(tile: Tile, *, loc=None, ip=None) -> ir.Value: + _check_tile(tile) + def expand_leaves(tile) -> list: leaves = [] for e in tile: @@ -176,6 +194,63 @@ def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple return transform_leaf(post_process, res) +#################################################################################################### +# Validation helpers +#################################################################################################### + + +def _check_shape(shape: Shape) -> None: + if is_integer(shape): + if isinstance(shape, int): + if shape <= 0: + raise ValueError( + f"Expected size in shape to be strictly positive, but got {shape}" + ) + elif isinstance(shape, Integer): + pass + else: + raise TypeError(f"Expected size be int or Integer, but got {type(shape)}") + elif isinstance(shape, tuple): + for s in shape: + _check_shape(s) + else: + raise ValueError( + f"Expected Shape, which is a positive integer or tuple of Shapes, but got {shape}" + ) + + +def _check_coord(coord: Coord) -> None: + flat_coord = flatten_to_tuple(coord) + if not all(is_integer(c) or c is None for c in flat_coord): + raise ValueError( + f"Expected Coord, whose leaves are integers or None, but got {coord}" + ) + + +def _check_stride(stride: Stride) -> None: + flat_stride = flatten_to_tuple(stride) + if not all(is_integer(s) or isinstance(s, ScaledBasis) for s in flat_stride): + raise ValueError( + f"Expected Stride, whose leaves are integers or ScaledBasis, but got {stride}" + ) + + +def _check_int_tuple(int_tuple: IntTuple) -> None: + flat_int_tuple = flatten_to_tuple(int_tuple) + if not all(is_integer(d) for d in flat_int_tuple): + raise ValueError( + f"Expected IntTuple, whose leaves are integers, but got {int_tuple}" + ) + + +def _check_tile(tile: Tile) -> None: + flat_tile = flatten_to_tuple(tile) + if not all(is_integer(t) or isinstance(t, _Layout) or t is None for t in flat_tile): + raise ValueError( + f"Expected Tile, whose leaves are integers or Layout or None, but got {tile}" + ) + + #################################################################################################### # # Core types @@ -428,7 +503,9 @@ class ScaledBasis: :type mode: Union[int, List[int]] :raises TypeError: If mode is not an integer or list of integers - **Examples**:: + **Examples:** + + .. code-block:: python # Create a scaled basis with integer scale and mode sb1 = ScaledBasis(2, 0) # 2 * E(0) @@ -572,7 +649,9 @@ def E(mode: Union[int, List[int]]) -> ScaledBasis: :rtype: ScaledBasis :raises TypeError: If mode is not an integer or a list - **Examples**:: + **Examples:** + + .. code-block:: python # Create a basis element for the first dimension (mode 0) e0 = E(0) @@ -664,7 +743,7 @@ class _Layout(Layout): :ivar stride: An IntTuple representing the strides of the layout. :ivar max_alignment: The maximum alignment of the layout. - **Examples**:: + **Examples:** .. code-block:: python @@ -796,7 +875,9 @@ class _Layout(Layout): :param idx: The linear index to convert. :return: The hierarchical coordinate corresponding to the index. - **Examples**:: + **Examples:** + + .. code-block:: python layout = make_layout((4, 8), stride=(8, 1)) @@ -836,7 +917,9 @@ class ComposedLayout(ir.Value): :ivar outer: The outer layout component :ivar max_alignment: The maximum alignment of the composed layout - **Examples**:: + **Examples:** + + .. code-block:: python # Create a composed layout with inner layout, offset, and outer layout @@ -856,11 +939,11 @@ class ComposedLayout(ir.Value): offset = composed.offset outer = composed.outer - # map coordinate (1, 2) to linear index - # - outer(1, 2) = (1, 2) - # - offset + outer(1, 2) = (1, 2) - # - inner(1, 2) = 1 * 1 + 2 * 4 = 9 - idx = crd2idx((1, 2), composed) + # map coordinate (0, 1) to linear index + # - outer(0, 1) = (0, 1) + # - offset + outer(0, 1) = (0, 1) + # - inner(0, 1) = 0 * 1 + 1 * 4 = 4 + idx = crd2idx((0, 1), composed) # Composition is used in many tiling operations # For example, in logical_product, raked_product, and blocked_product @@ -1091,7 +1174,7 @@ class _Pointer(Pointer): @ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True) -@ir.register_value_caster(_cute_ir.CountingTensorType.get_static_typeid(), replace=True) +@ir.register_value_caster(_cute_ir.CoordTensorType.get_static_typeid(), replace=True) @ir.register_value_caster( _cute_nvgpu_ir.SmemDescViewType.get_static_typeid(), replace=True ) @@ -1122,7 +1205,7 @@ class _Tensor(Tensor): - For composed layouts, stride information is not directly accessible - Dynamic layouts do not support vector load/store operations - Examples: + **Examples:** .. code-block:: python @@ -1178,7 +1261,7 @@ class _Tensor(Tensor): def __new_from_mlir_values__(self, values): # Only expecting single value of _Tensor or ir.Value # In this context, a _Tensor instance is an encapsulated ir.Value which is automatically created - # by value caster for MemRef/CountingTensor/SmemDescView typed values + # by value caster for MemRef/CoordTensor/SmemDescView typed values assert len(values) == 1, f"Expected 1 value, but got {len(values)}" assert isinstance( values[0], (_Tensor, ir.Value) @@ -1221,7 +1304,7 @@ class _Tensor(Tensor): :raises ValueError: If coordinate access is invalid for the tensor layout - Examples: + **Examples:** .. code-block:: python @@ -1235,7 +1318,7 @@ class _Tensor(Tensor): val = tensor[1] # Loads element at offset 4 (4bytes per Float32) val = tensor[(0, 1)] # Loads element at offset 64 - # Create a counting tensor + # Create a coord tensor layout = make_layout((64, 128), stride=(1 * E(0), 1 * E(1))) tensor = make_tensor((128, 128), layout) @@ -1251,7 +1334,7 @@ class _Tensor(Tensor): dereference operations. Attempting to set individual elements of tensors with these element types will result in errors. - Examples: + **Examples:** .. code-block:: python @@ -1268,7 +1351,7 @@ class _Tensor(Tensor): """ if has_underscore(crd): return slice_(self.value, crd) - elif isinstance(self.type, _cute_ir.CountingTensorType): + elif isinstance(self.type, _cute_ir.CoordTensorType): res = _cute_ir.get_iter(slice_(self, crd).value, loc=loc, ip=ip) return _unpack_x_tuple(res) else: @@ -1317,8 +1400,8 @@ class _Tensor(Tensor): :param crd: Coordinate or slice specification for tensor element assignment :type crd: Coord - :param value: Value to assign - can be scalar or TensorSSA for slice assignment - :type value: Union[int, float, ir.Value, TensorSSA] + :param data: Value to assign - can be scalar or TensorSSA for slice assignment + :type data: Union[int, float, ir.Value, Numeric, TensorSSA] :param loc: Source location for MLIR operation tracking, defaults to None :type loc: Optional[Location] :param ip: Insertion point for MLIR operation, defaults to None @@ -1334,7 +1417,7 @@ class _Tensor(Tensor): dereference operations. Attempting to set individual elements of tensors with these element types will result in errors. - Examples: + **Examples:** .. code-block:: python @@ -1408,15 +1491,18 @@ class _Tensor(Tensor): @property def leading_dim(self) -> Union[int, Tuple[int], None]: - """ - Get the leading dimension of this Tensor. + """Get the leading dimension of this Tensor. - Returns: - int: Single leading dimension index if found - Tuple[int, ...]: Tuple of indices for nested leading dimensions - None: If no leading dimension is found + :return: The index or indices of the first mode (from left to right) with stride 1 + :rtype: Union[int, Tuple[int], None] + :returns: + - int: Single leading dimension index if found + - Tuple[int]: Tuple of indices for nested leading dimensions + - None: If no leading dimension is found + + :postcondition: ``get(self.stride(), mode=self.leading_dim()) == 1 if self.leading_dim() != None else True`` """ - return find(1, self.stride, exclude_when=(1, self.shape)) + return leading_dim(self.shape, self.stride) @property @lru_cache_ir() @@ -1526,7 +1612,7 @@ class _Tensor(Tensor): :raises NotImplementedError: If tensor has dynamic size - Examples: + **Examples:** .. code-block:: python @@ -1570,15 +1656,30 @@ class _Tensor(Tensor): def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None): """Print content of the tensor in human readable format. - tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= - [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], - [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], - [ 0.3426, 0.5856, 0.1541, 0.2923, 0.6976], - [-0.1649, 0.8811, 0.1788, 0.1404, 0.2568], - [-0.2944, 0.8593, 0.4171, 0.8998, 0.1766], - [ 0.8814, 0.7919, 0.7390, 0.4566, 0.1576], - [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], - [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) + Outputs the tensor data in a structured format showing both metadata + and the actual data values. The output includes tensor type information, + layout details, and a formatted array representation of the values. + + :param tensor: The tensor to print + :type tensor: Tensor + :param verbose: If True, includes additional debug information in the output + :type verbose: bool + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer for IR generation, defaults to None + :type ip: insertion pointer, optional + :raises NotImplementedError: If the tensor type doesn't support trivial dereferencing + + Example output: + tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= + [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], + [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], + [ 0.3426, 0.5856, 0.1541, 0.2923, 0.6976], + [-0.1649, 0.8811, 0.1788, 0.1404, 0.2568], + [-0.2944, 0.8593, 0.4171, 0.8998, 0.1766], + [ 0.8814, 0.7919, 0.7390, 0.4566, 0.1576], + [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], + [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) """ if not isinstance(tensor.type, _cute_ir.MemRefType): raise NotImplementedError( @@ -1611,10 +1712,9 @@ def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None): @lru_cache_ir() def is_integer(a) -> bool: """Check if an object is static integer or dynamic integer""" - return ( - isinstance(a, int) - or isinstance(a, Integer) - or (isinstance(a, ir.Value) and isinstance(a.type, ir.IntegerType)) + return isinstance(a, (int, Integer)) or ( + isinstance(a, ir.Value) + and isinstance(a.type, (ir.IntegerType, _cute_ir.ConstrainedIntType)) ) @@ -1726,7 +1826,18 @@ def pretty_str(arg) -> str: @dsl_user_op def printf(*args, loc=None, ip=None) -> None: - """Print a value or a list of values. + """ + Print a value or a list of values. + + It supports c-style printf format as well: + + .. code-block:: python + + a = cute.make_layout(shape=(10, 10), stride=(10, 1)) + b = cutlass.Float32(1.234) + cute.printf(a, b) + cute.printf("a={}, b={}", a, b) + cute.printf("a={}, b=%.2f", a, b) :param args: List of values to print :type args: list @@ -1810,75 +1921,140 @@ def is_major(mode, stride: Stride, *, loc=None, ip=None) -> bool: return True if first_stride == 1 else False +def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: + """ + Find the leading dimension of a shape and stride. + + :param shape: The shape of the tensor or layout + :type shape: Shape + :param stride: The stride of the tensor or layout + :type stride: Stride + :return: The leading dimension index or indices + :rtype: Union[int, Tuple[int, ...], None] + + The return value depends on the stride pattern: + + * If a single leading dimension is found, returns an integer index + * If nested leading dimensions are found, returns a tuple of indices + * If no leading dimension is found, returns None + """ + + def pred_fn(val, pos): + # skip dynamic values which can't be compared + # find the candidate target val, stride at this position is 1 + if (not is_dynamic_expression(val)) and (val == 1): + # extract the shape at this position + mode = [pos] if isinstance(pos, int) else list(pos) + s = get(shape, mode) + if is_dynamic_expression(s) or s != 1: + # shape at this position is dynamic value or not 1 + # we found the leading dimension + return True + return False + + return find_if(stride, pred_fn=pred_fn) + + @dsl_user_op -def find( - x: int, +def find_if( t: Union[tuple, ir.Value, int], + pred_fn: Callable[[int, Tuple[int, ...]], bool], *, - exclude_when: Optional[IntTuple] = None, loc=None, ip=None, ) -> Union[int, Tuple[int, ...], None]: - """Find the first position of a x in t. - If exclude_when is provided, the positions where comparison equals comparison_value will be excluded from the search results. + """Find the first position in t where pred_fn(val, pos) returns True. - :param x: The static integer x to search for - :type x: int :param t: The search space :type t: Union[tuple, ir.Value, int] - :param exclude_when: A tuple of (comparison_value, comparison) - positions where comparison equals comparison_value will be excluded from the search results - :type exclude_when: Optional[Tuple[int, Union[tuple, ir.Value, int]]] + :param pred_fn: A callable object (lambda, function, etc.) that predicates the value and position in t. + It takes the current leaf value and position, returns True if the value or position is satisfied. + :type pred_fn: Callable[[int, Tuple[int, ...]], bool] + :return: Index if found at top level, tuple of indices showing nested position, or None if not found + :rtype: Union[int, Tuple[int, ...], None] + + Examples: + .. code-block:: python + + # Find the first position of x in t + t = (3, 4) + find_if(t, pred_fn=lambda val, pos: val == x) + + .. code-block:: python + + # find the leading dimension + shape = (3, 4) + stride = (4, 1) + # Find value 1 in stride where the corresponding shape is not 1 + def pred_fn(val, pos): + mode = [pos] if isinstance(pos, int) else list(pos) + return val == 1 and get(shape, mode) != 1 + find_if(stride, pred_fn=pred_fn) + """ + + def _find_if_impl(curr, pos, *, loc=None, ip=None): + if isinstance(curr, tuple): + # Recursively search nested tuple + for i in range(rank(curr)): + sub_curr = get(curr, mode=[i], loc=loc, ip=ip) + sub_pos = (pos, i) if isinstance(pos, int) else pos + (i,) + res_pos = _find_if_impl(sub_curr, sub_pos, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + else: + # For leaf values, check if it matches x + if pred_fn(curr, pos): + return pos + return None + + def _check_pred_fn(): + if not callable(pred_fn): + raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") + signature = inspect.signature(pred_fn) + if len(signature.parameters) != 2: + raise ValueError( + f"pred_fn must have two parameters (value, pos), but got {len(signature.parameters)}" + ) + + _check_pred_fn() + + for i in range(rank(t)): + curr = get(t, mode=[i], loc=loc, ip=ip) + res_pos = _find_if_impl(curr, i, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + return None + + +@dsl_user_op +def find( + t: Union[tuple, ir.Value, int], + x: int, + *, + loc=None, + ip=None, +) -> Union[int, Tuple[int, ...], None]: + """Find the first position of a value ``x`` in a hierarchical structure ``t``. + + Searches for the first occurrence of x in t, optionally excluding positions + where a comparison value matches. The search can traverse nested structures + and returns either a single index or a tuple of indices for nested positions. + + :param t: The search space + :type t: Union[tuple, ir.Value, int] + :param x: The static integer x to search for + :type x: int :return: Index if found at top level, tuple of indices showing nested position, or None if not found :rtype: Union[int, Tuple[int, ...], None] """ if not isinstance(x, int): raise TypeError(f"find() requires a static x to search for, but got {x}") - # Extract comparison value and tuple from exclude_when if provided - comparison_value, comparison = None, None - if exclude_when is not None: - comparison_value, comparison = exclude_when + def pred_fn(val, pos): + # Skip dynamic values which can't be compared + return not is_dynamic_expression(val) and val == x - # Iterate through t, checking both nested tuples and leaf values - for i in range(rank(t)): - # Get current elements from t and comparison - curr1 = get(t, mode=[i], loc=loc, ip=ip) - curr2 = ( - get(comparison, mode=[i], loc=loc, ip=ip) - if comparison is not None - else None - ) - - if isinstance(curr1, tuple): - # Recursively search nested tuple - sub_pos = find( - x, - curr1, - exclude_when=( - (comparison_value, curr2) if comparison is not None else None - ), - loc=loc, - ip=ip, - ) - if sub_pos is not None: - # Combine current index with recursive result - if isinstance(sub_pos, int): - return (i, sub_pos) - return (i,) + sub_pos - else: - # For leaf values, check if it matches x - # Skip dynamic expressions and Numeric types which can't be compared - if not (is_dynamic_expression(curr1) or isinstance(curr1, Numeric)): - if curr1 == x: - if ( - comparison is None - or is_dynamic_expression(curr2) - or isinstance(curr2, Numeric) - or curr2 != comparison_value - ): - return i - - return None + return find_if(t, pred_fn=pred_fn, loc=loc, ip=ip) def transform_leaf(f, *args): @@ -2081,7 +2257,9 @@ def get(input, mode: List[int], *, loc=None, ip=None): :raises ValueError: If any index in mode is out of range :raises TypeError: If mode contains non-integer elements or if input has unsupported type - **Examples**: + :postcondition: ``get(t, mode=find(x,t)) == x if find(x,t) != None else True`` + + **Examples:** For a layout like ((4,8),2):((16,1),8), get with mode=[0,1] would extract the element 8 from the shape component. @@ -2143,6 +2321,20 @@ def select(input, mode: List[int], *, loc=None, ip=None): :rtype: Layout, ComposedLayout, tuple :raises ValueError: If any index in mode is out of range :raises TypeError: If the input type is invalid + + **Examples:** + + .. code-block:: python + + # Select specific dimensions from a layout + layout = make_layout((4, 8, 16), stride=(32, 4, 1)) + selected = select(layout, mode=[0, 2]) # Select mode 0 and mode 2 + # Result: (4, 16):(32, 1) + + # Select elements from a tuple + t = (1, 2, 3, 4, 5) + selected = select(t, mode=[0, 2, 4]) # Select mode 0, mode 2, and mode 4 + # Result: (1, 3, 5) """ if any((not isinstance(i, int)) or (i >= rank(input)) for i in mode): raise ValueError( @@ -2208,7 +2400,7 @@ def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None): :return: A new object with the specified modes grouped :rtype: Same type as input with modified structure - Examples: + **Examples:** .. code-block:: python @@ -2273,7 +2465,7 @@ def slice_(src, coord: Coord, *, loc=None, ip=None): :rtype: Union[Tensor, Layout, IntTuple, tuple] :raises ValueError: If the coordinate pattern is incompatible with source - Examples: + **Examples:** .. code-block:: python @@ -2376,7 +2568,7 @@ def dice(src, dicer, *, loc=None, ip=None): :raises TypeError: If dicer has an unsupported type :raises ValueError: If input is not provided - Examples: + **Examples:** .. code-block:: python @@ -2513,7 +2705,7 @@ def prepend(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=No :raises ValueError: If up_to_rank is less than input's current rank :raises TypeError: If input or elem has unsupported type - Examples: + **Examples:** .. code-block:: python @@ -2582,7 +2774,7 @@ def append(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=Non :raises ValueError: If up_to_rank is less than input's current rank :raises TypeError: If input or elem has unsupported type - Examples: + **Examples:** .. code-block:: python @@ -2638,7 +2830,7 @@ def repeat_like(x, target): :return: A structure matching target but filled with x :rtype: Union[tuple, Any] - Examples: + **Examples:** .. code-block:: python @@ -2666,7 +2858,7 @@ def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple: :return: A flattened tuple containing all elements from the input :rtype: tuple - Examples: + **Examples:** .. code-block:: python @@ -2692,7 +2884,7 @@ def flatten(a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor]) -> tuple: :rtype: Union[tuple, Any] :raises NotImplementedError: If input is a Layout or Tensor - Examples: + **Examples:** .. code-block:: python @@ -3069,7 +3261,7 @@ def make_layout( :return: A new Layout object with the specified shape and stride :rtype: Layout - Examples: + **Examples:** .. code-block:: python @@ -3098,6 +3290,9 @@ def make_layout( * make_layout((3,4), (1,4)) can be confusing with make_layout(((3,4), (1,4))) * make_layout((3,4), stride=(1,4)) is more readable """ + if stride is not None and not is_congruent(shape, stride): + raise ValueError(f"shape and stride must be congruent") + shape_val = _pack_shape(shape, loc=loc, ip=ip) if stride is not None: stride_val = _pack_stride(stride, loc=loc, ip=ip) @@ -3127,7 +3322,7 @@ def make_identity_layout(shape: Shape, *, loc=None, ip=None) -> Layout: :return: A new identity Layout object with the specified shape :rtype: Layout - Examples: + **Examples:** .. code-block:: python @@ -3165,7 +3360,7 @@ def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Lay :return: A new Layout object with the specified shape and dimension ordering :rtype: Layout - Examples: + **Examples:** .. code-block:: python @@ -3184,7 +3379,7 @@ def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Lay - The length of order must match the rank of the shape """ shape_val = _pack_shape(shape, loc=loc, ip=ip) - order_val = _pack_shape(order, loc=loc, ip=ip) + order_val = _pack_int_tuple(order, loc=loc, ip=ip) return _cute_ir.make_ordered_layout( shape=shape_val, order=order_val, loc=loc, ip=ip ) @@ -3213,7 +3408,7 @@ def make_composed_layout( :return: A new ComposedLayout representing the composition :rtype: ComposedLayout - Examples: + **Examples:** .. code-block:: python @@ -3510,7 +3705,7 @@ def make_tensor( :raises ValueError: If iterator type is not supported - Examples: + **Examples:** .. code-block:: python @@ -3522,7 +3717,7 @@ def make_tensor( layout = make_layout(((128, 8), (1, 4, 1)), stride=((32, 1), (0, 8, 4096))) tensor = make_tensor(smem_ptr, layout) - # Create a counting tensor + # Create a coord tensor layout = make_layout(2, stride=16 * E(0)) tensor = make_tensor(5, layout) @@ -3540,7 +3735,7 @@ def make_tensor( ty = None if is_integer(iterator) or isinstance(iterator, tuple): iterator = _pack_int_tuple(iterator, loc=loc, ip=ip) - ty = _cute_ir.CountingTensorType.get(iterator.type, layout.type) + ty = _cute_ir.CoordTensorType.get(iterator.type, layout.type) elif isinstance(iterator, Pointer): iterator = iterator.value ty = _cute_ir.MemRefType.get(iterator.type, layout.type) @@ -3568,17 +3763,17 @@ def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor: :return: A tensor that maps each coordinate to itself :rtype: Tensor - Examples: + **Examples:** .. code-block:: python - # Create a simple 1D counting tensor + # Create a simple 1D coord tensor tensor = make_identity_tensor(6) # [0,1,2,3,4,5] - # Create a 2D counting tensor + # Create a 2D coord tensor tensor = make_identity_tensor((3,2)) # [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)] - # Create hierarchical counting tensor + # Create hierarchical coord tensor tensor = make_identity_tensor(((2,1),3)) # [((0,0),0),((1,0),0),((0,0),1),((1,0),1),((0,0),2),((1,0),2)] @@ -3654,7 +3849,7 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None): :return: A new layout or fragment tensor with matching shape :rtype: Union[Layout, Tensor] - **Examples** + **Examples:** Creating a rmem tensor from a tensor: @@ -3699,7 +3894,7 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None): else: return new_layout elif isinstance(src, Tensor): - if isinstance(src.type, _cute_ir.CountingTensorType): + if isinstance(src.type, _cute_ir.CoordTensorType): if dtype is None: raise ValueError( "dtype must be provided when src is a coordinate tensor" @@ -4094,7 +4289,7 @@ def tile_to_shape( ip=None, ) -> Union[Layout, ComposedLayout]: trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip) - order = _pack_shape(order, loc=loc, ip=ip) + order = _pack_int_tuple(order, loc=loc, ip=ip) return _cute_ir.tile_to_shape(atom, trg_shape, order, loc=loc, ip=ip) @@ -4380,7 +4575,6 @@ class MmaAtom(Atom): ip=ip, ) - @dsl_user_op def make_fragment_B(self, input, *, loc=None, ip=None): if isinstance(input, _Tensor): @@ -4831,24 +5025,44 @@ def make_copy_atom( def make_layout_tv( thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None ) -> Tuple[Shape, Layout]: - """ - Create a tiled copy given separate thr and val layouts. A TV partitioner is inferred based on inputs. - Requires input thr layout be compact. + """Create a thread-value layout for partitioning data tensors. - Parameters - ---------- - atom : copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. - thr_layout : mn -> tid (need to be compact?) - val_layout : mn -> vid - loc : source location for mlir (optional) - ip : insertion point (optional) + This function creates a thread-value layout that maps between ``(thread_idx, value_idx)`` + coordinates and logical ``(M,N)`` coordinates. The thread layout must be compact to ensure + proper partitioning. - Returns - ------- - layout_mn - logical tile size - layout_tv - thread-value layout (tid, vid) -> mn + This implements the thread-value partitioning pattern shown in + Figure TVLayout, where data is partitioned across threads and values within each thread. + + :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) + :type thr_layout: Layout + :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs within each thread + :type val_layout: Layout + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tuple containing ``tiler_mn`` and ``layout_tv`` + :rtype: Tuple[Shape, Layout] + + where: + * ``tiler_mn`` is tiler and ``shape(tiler_mn)`` is compatible with ``shape(zipped_divide(x, tiler_mn))[0]`` + * ``layout_tv``: Thread-value layout mapping (thread_idx, value_idx) -> (M,N) + + **Example:** + + .. code-block:: python + + tiler_mn, layout_tv = cute.make_layout_tv( + cute.make_layout((4, 8), stride=(8, 1)), cute.make_layout(2, stride=1) + ) + + Above code creates a TV layout that maps between thread/value coordinates + and the logical coordinates in a 8x8 matrix with: + + * thread block layout ``(4,8):(8,1)`` + * 2 elements per thread """ # Take the raked_products to compute the Layout_MN @@ -4869,22 +5083,24 @@ def make_layout_tv( @dsl_user_op def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> TiledCopy: - """ - Create a tiled copy given separate thr and val layouts. A TV partitioner is inferred based on inputs. - Requires input thr layout be compact. + """Create a tiled copy given separate thread and value layouts. - Parameters - ---------- - atom : copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. - thr_layout : mn -> tid (need to be compact?) - val_layout : mn -> vid - loc : source location for mlir (optional) - ip : insertion point (optional) + A TV partitioner is inferred based on the input layouts. The input thread layout + must be compact. - Returns - ------- - tiled_copy - A tiled copy for partitioner + :param atom: Copy atom + :type atom: CopyAtom + :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) + :type thr_layout: Layout + :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs + :type val_layout: Layout + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy """ tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip) @@ -4905,21 +5121,21 @@ def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> Ti @dsl_user_op def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): - """ - Create a tiled type given a TV partitioner and tiler + """Create a tiled type given a TV partitioner and tiler. - Parameters - ---------- - atom : copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. - layout_tv : thread-value layout. - tiler_mn : tile size (??) - loc : source location for mlir (optional) - ip : insertion point (optional) + :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. + :type atom: CopyAtom + :param layout_tv: Thread-value layout + :type layout_tv: Layout + :param tiler_mn: Tile size + :type tiler_mn: Tiler + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional - Returns - ------- - tiled_copy - A tuple of A tiled copy and atom + :return: A tiled copy for the partitioner + :rtype: TiledCopy """ # tiler_mn = pack_tuple(tiler_mn, make_tile) @@ -4941,20 +5157,19 @@ def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): @dsl_user_op def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): - """ - Create a tiled type out of the copy_atom that matches the Src-Layout of tiled_copy. + """Create a tiled copy out of the copy_atom that matches the Src-Layout of tiled_copy. - Parameters - ---------- - atom : copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. - tiled_copy : tiled copy - loc : source location for mlir (optional) - ip : insertion point (optional) + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional - Returns - ------- - tiled_copy - A tuple of A tiled copy and atom + :return: A tiled copy for the partitioner + :rtype: TiledCopy """ return make_tiled_copy( @@ -4964,20 +5179,19 @@ def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): @dsl_user_op def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): - """ - Create a tiled type out of the copy_atom that matches the Dst-Layout of tiled_copy. + """Create a tiled copy out of the copy_atom that matches the Dst-Layout of tiled_copy. - Parameters - ---------- - atom : copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. - tiled_copy : tiled copy - loc : source location for mlir (optional) - ip : insertion point (optional) + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional - Returns - ------- - tiled_copy - A tuple of A tiled copy and atom + :return: A tiled copy for the partitioner + :rtype: TiledCopy """ return make_tiled_copy( @@ -4987,21 +5201,21 @@ def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): @dsl_user_op def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): - """ - Create the smallest tiled copy that can retile LayoutC_TV - for use with pipelined epilogues with subtiled stores + """Create the smallest tiled copy that can retile LayoutC_TV for use with pipelined epilogues with subtiled stores. - Parameters - ---------- - atom: CopyAtom - mma : TiledMma - loc : source location for mlir (optional) - ip : insertion point (optional) + :param atom: Copy atom + :type atom: CopyAtom + :param mma: Tiled MMA + :type mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional - Returns - ------- - tiled_copy - A tiled copy for partitioner + :return: A tiled copy for partitioner + :rtype: TiledCopy + + :raises ValueError: If the number value of CopyAtom's source layout is greater than the size of TiledMma's LayoutC_TV """ # Truncate the V-layout to just the Copy_Atom, keep the V-order layoutC_tv = mma.tv_layout_C_tiled @@ -5081,29 +5295,60 @@ def gemm( ip=None, **kwargs, ) -> None: - """ - The GEMM algorithm. + """The GEMM algorithm. Computes ``D <- AB + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field. All tensors must be partitioned according to the provided MMA Atom. + + For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread + election internally. Manual thread selection is not required in such cases. + + :param atom: MMA atom + :type atom: MmaAtom + :param d: Destination tensor + :type d: Tensor + :param a: First source tensor + :type a: Tensor + :param b: Second source tensor + :type b: Tensor + :param c: Third source tensor + :type c: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR, defaults to None + :type ip: Optional[InsertionPoint], optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: None + :rtype: None """ + value = atom._unpack(loc=loc, ip=ip, **kwargs) return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip) @dsl_user_op def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """ - Performs a basic element-wise copy. + """Performs a basic element-wise copy. This functions **assumes** the following pre-conditions: 1. `size(src) == size(dst)` When the `src` and `dst` shapes are static, the pre-conditions are actually verified and the element-wise loop is fully unrolled. + + :param src: Source tensor + :type src: Tensor + :param dst: Destination tensor + :type dst: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional """ + if is_static(src.shape) and is_static(dst.shape): simt_copy_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( src.element_type.mlir_type, src.element_type.width @@ -5120,8 +5365,7 @@ def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: @dsl_user_op def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """ - Performs a basic predicated element-wise copy. + """Performs a basic predicated element-wise copy. This functions **assumes** the following pre-conditions: 1. `size(src) == size(dst)` @@ -5129,6 +5373,7 @@ def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) When all shapes are static, the pre-conditions are actually verified and the element-wise loop is fully unrolled. + """ if src.element_type.width != dst.element_type.width: raise NotImplementedError( @@ -5252,6 +5497,9 @@ def copy( An additional predication tensor can be provided. If the partitioned tensors have the following logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile consistent with ``(ATOM_REST,REST_M,...)``. + + For Copy Atoms that require single-threaded execution, the copy op automatically handles thread + election internally. Manual thread selection is not required in such cases. """ if isinstance(src.type, _cute_ir.MemRefType) and isinstance( dst.type, _cute_ir.MemRefType @@ -5761,7 +6009,7 @@ class TensorSSA(cutlass_arith.ArithValue): :raises ValueError: If coordinate access is invalid for the tensor layout - Examples: + **Examples:** .. code-block:: python @@ -5903,9 +6151,13 @@ class TensorSSA(cutlass_arith.ArithValue): :return: The reduced tensor :rtype: TensorSSA - Examples: + **Examples:** + + .. code-block:: python + reduce(f32 o (4,)) => f32 + reduce(f32 o (4, 5)) => f32 reduce(f32 o (4, (5, 4)), reduction_profile=(_, 1)) @@ -5917,6 +6169,12 @@ class TensorSSA(cutlass_arith.ArithValue): if reduction_profile is None: return self + if not is_weakly_congruent(reduction_profile, self.shape): + raise ValueError( + f"Expect reduction_profile be weakly congruent to the shape of the tensor, " + f"but got {reduction_profile} and {self.shape}" + ) + if op is ReductionOp.ADD: red_kind = vector.CombiningKind.ADD elif op is ReductionOp.MUL: @@ -6008,7 +6266,7 @@ def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> Tenso def full_like( - a: TensorSSA, + a: Union[TensorSSA, Tensor], fill_value, dtype: Union[None, Type[Numeric]] = None, *, @@ -6033,14 +6291,17 @@ def full_like( :func:`zeros_like`: Return an array of zeros with shape and type of input. :func:`full`: Return a new array of given shape filled with value. - Examples - -------- + **Examples:** + .. code-block:: python frg = cute.make_fragment(Float32, (2, 3)) a = frg.load() b = cute.full_like(a, 1.0) """ + if not hasattr(a, "shape"): + raise TypeError(f"Expect `a` be shaped type, but got {type(a)}") + return full( a.shape, fill_value, dtype if dtype is not None else a.dtype, loc=loc, ip=ip ) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py index 322e8bf0..b4b88031 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py @@ -26,7 +26,7 @@ __all__ = [ # # helpers.py # - "make_tma_tile_atom", + "make_tiled_tma_atom", "tma_partition", "create_tma_multicast_mask", "prefetch_descriptor", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py index 8de65a72..8744a376 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py @@ -127,7 +127,12 @@ class CopyBulkTensorTileG2SOp(CopyOp): cta_group: CtaGroup = CtaGroup.ONE - admissible_archs = ["sm_90", "sm_90a", "sm_100a"] + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] def __post_init__(self) -> None: if not isinstance(self.cta_group, CtaGroup): @@ -159,7 +164,7 @@ class CopyBulkTensorTileG2SOp(CopyOp): self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "CopyBulkTensorTileG2SNonExecTrait": raise NotImplementedError( - "Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA" + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" ) def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: @@ -224,7 +229,12 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp): cta_group: CtaGroup = CtaGroup.ONE - admissible_archs = ["sm_90", "sm_90a", "sm_100a"] + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] def __post_init__(self): if not isinstance(self.cta_group, CtaGroup): @@ -256,7 +266,7 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp): self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait": raise NotImplementedError( - "Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA" + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" ) def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: @@ -326,7 +336,12 @@ class CopyBulkTensorTileS2GOp(CopyOp): This Operation uses TMA in the ``.tile`` mode. """ - admissible_archs = ["sm_90", "sm_90a", "sm_100a"] + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] def __post_init__(self): # Arch verification @@ -345,7 +360,7 @@ class CopyBulkTensorTileS2GOp(CopyOp): self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "CopyBulkTensorTileS2GTrait": raise NotImplementedError( - "Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA" + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" ) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py index 92f028a2..89f7061e 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -29,14 +29,14 @@ from .copy import ( @dsl_user_op -def make_tma_tile_atom( +def make_tiled_tma_atom( op: Union[ CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, ], gmem_tensor: Tensor, - smem_layout: Layout, + smem_layout: Union[Layout, core.ComposedLayout], cta_tiler: Tiler, num_multicast: int = 1, *, @@ -45,7 +45,7 @@ def make_tma_tile_atom( ip=None, ) -> Tuple[core.CopyAtom, Tensor]: """ - Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from and SMEM + Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM buffer with the given Layout. Given @@ -71,7 +71,7 @@ def make_tma_tile_atom( :param gmem_tensor: The GMEM tensor involved in the Copy :type gmem_tensor: Tensor :param smem_layout: The SMEM layout to construct the Copy Atom for - :type smem_layout: Layout + :type smem_layout: Union[Layout, core.ComposedLayout] :param cta_tiler: The CTA Tiler to use :type cta_tiler: Tiler :param num_multicast: The multicast factor @@ -94,6 +94,12 @@ def make_tma_tile_atom( ip=ip, ) + # Wrap smem_layout in a composed layout to make it a TMA-friendly layout + if isinstance(smem_layout, Layout): + smem_layout = core.make_composed_layout( + core.make_swizzle(0, 4, 3), 0, smem_layout + ) + if isinstance(op, CopyBulkTensorTileG2SOp): if num_multicast != 1: raise ValueError( diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py index 020b96d8..ccd06d01 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py @@ -34,7 +34,7 @@ from .cpasync.copy import ( @dsl_user_op -def make_tma_tile_atom_A( +def make_tiled_tma_atom_A( op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], gmem_tensor: Tensor, smem_layout: Layout, @@ -46,6 +46,51 @@ def make_tma_tile_atom_A( loc=None, ip=None, ) -> Tuple[core.CopyAtom, Tensor]: + """ + Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation + accounting for the MK projections of the TiledMMA for A tensor loads. + + Given + + - a GMEM tensor + - a SMEM layout + - a MMA Tiler + - a TiledMma + - a Cluster-level shape + + this function figures out the bulk tensor asynchronous copy instruction to use with the maximum + "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided + layout and consistent with the provided Tiler & tiled_mma (considering the M-mode & K-mode). + The Cluster-level shape is used to determine the multicast factor across the N-mode for A tensor loads. + + This function returns two results: + + 1. the Copy Atom + 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates + that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the + associated layout can output coordinates. Otherwise, TMA tensors can be partitioned + similarly to any other CuTe tensors using the algebra. + + :param op: The Copy Operation to construct an Atom for + :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp] + :param gmem_tensor: The GMEM tensor to be loaded by this copy atom + :type gmem_tensor: Tensor + :param smem_layout: Shared memory layout to load the tensor into (PDSL) + :type smem_layout: Layout + :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions + :type mma_tiler_mnk: Shape + :param tiled_mma: The TiledMMA that will consume the load as operands + :type tiled_mma: core.TiledMma + :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions + :type cluster_shape_vmnk: Shape + :param internal_type: An optional parameter for the internal data type to when element + type does not match the copy type + :type internal_type: Type[Numeric] + :return: A copy atom for this operation and the associated TMA coord tensor + :rtype: Tuple[core.CopyAtom, Tensor] + + """ + if internal_type is not None: if not isinstance(internal_type, NumericMeta): raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") @@ -54,7 +99,7 @@ def make_tma_tile_atom_A( op, [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], "op", - "make_tma_tile_atom_A", + "make_tiled_tma_atom_A", ) ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) @@ -94,7 +139,7 @@ def make_tma_tile_atom_A( @dsl_user_op -def make_tma_tile_atom_B( +def make_tiled_tma_atom_B( op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], gmem_tensor: Tensor, smem_layout: Layout, @@ -106,6 +151,51 @@ def make_tma_tile_atom_B( loc=None, ip=None, ) -> Tuple[core.CopyAtom, Tensor]: + """ + Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation + accounting for the NK projections of the TiledMMA for B tensor loads. + + Given + + - a GMEM tensor + - a SMEM layout + - a MMA Tiler + - a TiledMma + - a Cluster-level shape + + this function figures out the bulk tensor asynchronous copy instruction to use with the maximum + "TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided + layout and consistent with the provided Tiler & tiled_mma (considering the N-mode & K-mode). + The Cluster-level shape is used to determine the multicast factor across the M-mode for B tensor loads. + + This function returns two results: + + 1. the Copy Atom + 2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates + that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the + associated layout can output coordinates. Otherwise, TMA tensors can be partitioned + similarly to any other CuTe tensors using the algebra. + + :param op: The Copy Operation to construct an Atom for + :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp] + :param gmem_tensor: The GMEM tensor to be loaded by this copy atom + :type gmem_tensor: Tensor + :param smem_layout: Shared memory layout to load the tensor into (PDSL) + :type smem_layout: Layout + :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions + :type mma_tiler_mnk: Shape + :param tiled_mma: The TiledMMA that will consume the load as operands + :type tiled_mma: core.TiledMma + :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions + :type cluster_shape_vmnk: Shape + :param internal_type: An optional parameter for the internal data type to when element + type does not match the copy type + :type internal_type: Type[Numeric] + :return: A Copy Atom for this Operation and the associated TMA tensor + :rtype: Tuple[core.CopyAtom, Tensor] + + """ + if internal_type is not None: if not isinstance(internal_type, NumericMeta): raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") @@ -114,7 +204,7 @@ def make_tma_tile_atom_B( op, [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], "op", - "make_tma_tile_atom_B", + "make_tiled_tma_atom_B", ) ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) @@ -154,6 +244,6 @@ def make_tma_tile_atom_B( __all__ = [ - "make_tma_tile_atom_A", - "make_tma_tile_atom_B", + "make_tiled_tma_atom_A", + "make_tiled_tma_atom_B", ] diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py index 283cf8fb..1c4439a0 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -98,7 +98,10 @@ class _LdBase(CopyOp): repeat: Repetition = Repetition.x1 pack: Pack = Pack.NONE - admissible_archs = ["sm_100a"] + admissible_archs = [ + "sm_100a", + "sm_100f", + ] def __post_init__(self) -> None: # Arch verification @@ -284,7 +287,10 @@ class _StBase(CopyOp): repeat: Repetition unpack: Unpack = Unpack.NONE - admissible_archs = ["sm_100a"] + admissible_archs = [ + "sm_100a", + "sm_100f", + ] def __post_init__(self) -> None: # Arch verification diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py index b5d681cf..7f8f0de9 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -136,7 +136,10 @@ class MmaOp(MmaOp): a_major_mode: OperandMajorMode b_major_mode: OperandMajorMode - admissible_archs = ["sm_100a"] + admissible_archs = [ + "sm_100a", + "sm_100f", + ] def __post_init__(self) -> None: # Verify arch diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py index ca9177f3..275861f7 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py @@ -339,10 +339,10 @@ class MmaF8Op(MmaOp): "expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", ) # Accumulator data type verification - if self.acc_dtype != Float32: + if self.acc_dtype not in [Float16, Float32]: raise OpError( self, - "expects the 'acc_dtype' Op parameter to be Float32", + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", ) # Verify the instruction shape instruction_k = 32 diff --git a/python/CuTeDSL/cutlass/cute/runtime.py b/python/CuTeDSL/cutlass/cute/runtime.py index ec0ccc5f..c86890d3 100644 --- a/python/CuTeDSL/cutlass/cute/runtime.py +++ b/python/CuTeDSL/cutlass/cute/runtime.py @@ -20,6 +20,7 @@ from typing import Union from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir +from cutlass.base_dsl.dsl import is_dynamic_expression from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry # Local modules imports @@ -45,7 +46,8 @@ from .typing import ( BFloat16, Float8E5M2, ) -from .core import find, _Tensor as CoreTensor +from . import core +from .core import _Tensor as CoreTensor class _Pointer(Pointer): @@ -131,6 +133,9 @@ class _Pointer(Pointer): def memspace(self): return self._addr_space + def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: + raise NotImplementedError("align is not supported in runtime") + def verify(self, expected_py_type): if expected_py_type is Pointer: return True @@ -361,7 +366,7 @@ class _Tensor(Tensor): * If nested leading dimensions are found, returns a tuple of indices * If no leading dimension is found, returns None """ - return find(1, self.stride, exclude_when=(1, self.shape)) + return core.leading_dim(self.shape, self.stride) def fill(self, value: Numeric): raise TypeError(f"fill function is not supported in runtime") @@ -479,12 +484,8 @@ class TensorAdapter: Convert a DLPack protocol supported tensor/array to a cute tensor. """ - # Need reference these capsules to avoid being garbage collected - tensor_capsules = [] - def __init__(self, arg): self._arg = from_dlpack(arg).mark_layout_dynamic() - self.tensor_capsules.append(self._arg) def __new_from_mlir_values__(self, values): return self._arg.__new_from_mlir_values__(values) diff --git a/python/CuTeDSL/cutlass/cute/testing.py b/python/CuTeDSL/cutlass/cute/testing.py index 2d81a9c7..4a1bb016 100644 --- a/python/CuTeDSL/cutlass/cute/testing.py +++ b/python/CuTeDSL/cutlass/cute/testing.py @@ -9,29 +9,26 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -import random -import numpy as np import functools -import hashlib - -from cutlass.cutlass_dsl import ( - const, - T, - CuTeDSL, - BaseDSL, - t, - Constexpr, - detect_gpu_arch, -) - -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.ir as ir -from cutlass._mlir.dialects import nvvm, cf, vector, builtin - -from cutlass.cute import core -from cutlass.cute import nvgpu -from typing import Type +import inspect +import logging +import os +from enum import Enum from inspect import isclass +from itertools import product +from time import time +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import cuda.bindings.driver as cuda_driver +import cuda.bindings.runtime as cuda_runtime +import numpy as np + +import cutlass._mlir.ir as ir +import cutlass.base_dsl.jit_executor +import cutlass.cute as cute +from cutlass._mlir.dialects import builtin, cf, nvvm, vector +from cutlass.cute import core, nvgpu +from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t def assert_(cond, msg=None): @@ -248,9 +245,10 @@ def sample_pytest(rand_cfg=None): import functools import os import random - import pytest import sys + import pytest + seed, sample_ratio = rand_cfg random.seed(seed) @@ -270,3 +268,311 @@ def sample_pytest(rand_cfg=None): return wrapper return decorator + + +######################################### +# Benchmarking utilities +######################################### + + +class JitArguments: + """ + A type to hold both args and kwargs for passing to a kernel while benchmarking. + """ + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +def _cuda_success( + err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str +): + """ + Helper function to check CUDA API errors. + """ + if isinstance(err, tuple): + _cuda_success(err[0], message) + elif isinstance(err, cuda_runtime.cudaError_t): + error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8") + if err != cuda_runtime.cudaError_t.cudaSuccess: + raise RuntimeError(f"{message} : {error_message}") + elif isinstance(err, cuda_driver.CUresult): + if err != cuda_driver.CUresult.CUDA_SUCCESS: + error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8") + raise RuntimeError(f"{message} : {error_message}") + else: + raise TypeError( + f"{err} is an unexpected type : it should be a cudaError_t or CUresult" + ) + + +def _does_kernel_use_stream( + kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs +): + """ + This function checks if the kernel uses the provided non-default stream. + It does this by capturing the stream and then checking if any kernels were launched. + :param kernel: The kernel to check + :type kernel: Callable + :param stream: The stream to check + :type stream: cuda_driver.CUstream + :return: True if the kernel uses the stream, False otherwise + :rtype: bool + """ + + assert int(stream) != int( + cuda_driver.CUstream_flags.CU_STREAM_DEFAULT + ), "Stream must be a non-default stream" + + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + + kernel(*args, **kwargs) + + err, graph = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Get number of nodes in warmup graph to check it matches what is expected + err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph) + _cuda_success(err, "Error on querying graph") + return num_nodes > 0 + + +def benchmark( + callable: Callable, + *, + warmup_iterations: int = 10, + profiling_iterations: int = 100, + stream: Optional[cuda_driver.CUstream] = None, + kernel_arguments: Optional[JitArguments] = None, + workspace_generator: Optional[Callable[[], JitArguments]] = None, + workspace_count: int = 1, + use_cuda_graphs: bool = False, +) -> float: + """Benchmarks a callable function with the specified parameters. + + For example, + .. code-block:: python + + from cutlass.cute.testing import benchmark + + @cute.jit + def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream): + # contents of the function + pass + + time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream) + warmup_iterations=10, profiling_iterations=100 + stream=stream) + + To prevent skewing results by repeately accessing the L2 cache, use the workspace_count and workspace_generator + parameters to cycle through a number of different workspaces. + + .. code-block:: python + + from cutlass.cute.testing import benchmark + + @cute.jit + def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor): + # contents of the function + pass + + def workspace_generator(): + # create a, b, and c + return JitArguments(a, b, c) + + time_us = benchmark(user_function, + workspace_generator=workspace_generator, + workspace_count=10, + warmup_iterations=10000, + profiling_iterations=1000) + + To benchmark you may always configure the function being profiled (callable), the warmup iterations, and + the number of profiling iterations. + + Whenever the kernel being benchmarked runs in a non-default stream, the stream must be provided through the stream parameter. + + To use CUDA graphs, the callable must be a compiled @cute.jit annotated function. + When using CUDA graphs, the kernel must be launched in a non-default stream. + + :param callable: The function to benchmark + :type callable: Callable + :param warmup_iterations: Number of warmup iterations, defaults to 10 + :type warmup_iterations: int, optional + :param profiling_iterations: Number of benchmark iterations, defaults to 100 + :type profiling_iterations: int, optional + :param stream: Stream kernel is launched in, defaults to CUDA stream default + :type stream: CUstream, None + :param kernel_arguments: Kernel arguments to launch callable with, defaults to None + :type kernel_arguments: JitArguments, None + :param workspace_generator: Function that returns kernel arguments, defaults to None + :type workspace_generator: Callable + :param workspace_count: Number of workspaces (arguments) to loop through, looping through enough workspaces will keep the L2 cache cold + :type workspace_count: int, optional + :param use_cuda_graphs: Whether to use cuda graphs, defaults to False + :type use_cuda_graphs: bool, optional + + :return: The benchmark time in microseconds + :rtype: float + """ + + if stream is None: + stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) + + if workspace_count < 1: + raise ValueError("workspace_count must be at least 1") + + time_us = float("nan") + if workspace_generator == None: + # If no workspace generator is provided, we need a single workspace + if workspace_count != 1: + raise ValueError("Need a single workspace if not providing a generator") + + # If no workspace generator is provided, we need a kernel_argument + if kernel_arguments == None: + raise ValueError( + "Please pass a kernel argument if not providing a generator" + ) + workspace_generator = lambda: kernel_arguments + + workspaces = [workspace_generator() for _ in range(workspace_count)] + + for workspace in workspaces: + if type(workspace) != JitArguments: + raise TypeError( + "workspace_generator and/or kernel_arguments should use JitArguments type" + ) + + def _loop_and_call_kernel(iterations: int, workspace_index: int = 0): + for _ in range(iterations): + current_workspace = workspaces[workspace_index] + callable(*current_workspace.args, **current_workspace.kwargs) + workspace_index = (workspace_index + 1) % workspace_count + return workspace_index + + # Create CUDA events for timing + err, start_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + err, end_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + + elapsed_time = float("nan") + + if use_cuda_graphs: + # Check if the callable is a JitExecutor + if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor): + raise TypeError("Function must be precompiled to be used with CUDA Graphs") + + # Check if the stream is a non-default stream + if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT): + raise ValueError( + "Measuring with CUDA Graphs requires executing in a non-default stream" + ) + + workspace_index = 0 + + # Capture warmup graph + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + + workspace_index = _loop_and_call_kernel(warmup_iterations) + err, gwarm = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Get number of nodes in warmup graph to check it matches what is expected + err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm) + _cuda_success(err, "Error on querying graph") + # Assertion is >= since we may launch multiple kernels in one host function + if num_nodes < warmup_iterations: + raise ValueError( + f"CUDA stream passed to benchmark does not match the stream the kernel was launched in" + ) + + # Capture profiling graph + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + _loop_and_call_kernel(profiling_iterations, workspace_index) + err, gprofile = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Instantiate graphs + err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0) + _cuda_success(err, "Error on graph instantiation") + err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0) + _cuda_success(err, "Error on graph instantiation") + + # Launch warmup graph + err = cuda_runtime.cudaGraphLaunch(gwarm, stream) + _cuda_success(err, "Error on graph launch") + + # Record start time + err = cuda_driver.cuEventRecord(start_event, stream) + _cuda_success(err, "Error on recording event") + + # Launch profiling graph + err = cuda_runtime.cudaGraphLaunch(gprofile, stream) + _cuda_success(err, "Error on graph launch") + + # Record end time + err = cuda_driver.cuEventRecord(end_event, stream) + _cuda_success(err, "Error on recording event") + err = cuda_driver.cuEventSynchronize(end_event) + _cuda_success(err, "Error on synchronizing event") + + # Get elapsed time + err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying event") + + # Destroy graphs + err = cuda_runtime.cudaGraphExecDestroy(gwarm) + _cuda_success(err, "Error on destroying graph") + err = cuda_runtime.cudaGraphExecDestroy(gprofile) + _cuda_success(err, "Error on destroying graph") + + else: + + if int(stream) != int( + cuda_driver.CUstream_flags.CU_STREAM_DEFAULT + ) and not _does_kernel_use_stream( + callable, stream, *workspaces[0].args, **workspaces[0].kwargs + ): + raise ValueError( + "CUDA stream passed to benchmark does not match the stream the kernel was launched in" + ) + + # Not using graphs + # Warmup + workspace_index = _loop_and_call_kernel(warmup_iterations) + # Record start event + err = cuda_driver.cuEventRecord(start_event, stream) + _cuda_success(err, "Error on recording event") + _loop_and_call_kernel(profiling_iterations, workspace_index) + # Record end event + err = cuda_driver.cuEventRecord(end_event, stream) + _cuda_success(err, "Error on recording event") + # Synchronize end event + err = cuda_driver.cuEventSynchronize(end_event) + _cuda_success(err, "Error on synchronizing event") + err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying event") + + # Destroy events + err = cuda_driver.cuEventDestroy(start_event) + _cuda_success(err, "Error on destroying event") + err = cuda_driver.cuEventDestroy(end_event) + _cuda_success(err, "Error on destroying event") + + return elapsed_time / profiling_iterations * 1e3 + + diff --git a/python/CuTeDSL/cutlass/cute/typing.py b/python/CuTeDSL/cutlass/cute/typing.py index a13ebeaf..215e71d9 100644 --- a/python/CuTeDSL/cutlass/cute/typing.py +++ b/python/CuTeDSL/cutlass/cute/typing.py @@ -68,6 +68,8 @@ class Pointer(ABC): @property def dtype(self) -> Type[Numeric]: ... + def align(self, min_align: int) -> "Pointer": ... + def __get_mlir_types__(self) -> List[ir.Type]: ... def __extract_mlir_values__(self) -> List[ir.Value]: ... diff --git a/python/CuTeDSL/cutlass/pipeline/__init__.py b/python/CuTeDSL/cutlass/pipeline/__init__.py new file mode 100644 index 00000000..d2729787 --- /dev/null +++ b/python/CuTeDSL/cutlass/pipeline/__init__.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .helpers import ( + Agent, + CooperativeGroup, + PipelineOp, + SyncObject, + MbarrierArray, + NamedBarrier, + TmaStoreFence, + PipelineUserType, + PipelineState, + make_pipeline_state, + pipeline_init_wait, + arrive, + arrive_unaligned, + wait, + wait_unaligned, + arrive_and_wait, + sync, +) + +from .sm90 import ( + PipelineAsync, + PipelineTmaAsync, + PipelineTmaMultiConsumersAsync, + PipelineTmaStore, +) + +from .sm100 import ( + PipelineTmaUmma, + PipelineAsyncUmma, + PipelineUmmaAsync, +) + +__all__ = [ + "Agent", + "CooperativeGroup", + "PipelineOp", + "SyncObject", + "MbarrierArray", + "NamedBarrier", + "TmaStoreFence", + "PipelineUserType", + "PipelineState", + "PipelineAsync", + "PipelineTmaAsync", + "PipelineTmaUmma", + "PipelineTmaMultiConsumersAsync", + "PipelineAsyncUmma", + "PipelineUmmaAsync", + "PipelineTmaStore", +] diff --git a/python/CuTeDSL/cutlass/pipeline/helpers.py b/python/CuTeDSL/cutlass/pipeline/helpers.py new file mode 100644 index 00000000..68acfdab --- /dev/null +++ b/python/CuTeDSL/cutlass/pipeline/helpers.py @@ -0,0 +1,645 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, Int32, Int64, if_generate +from cutlass._mlir.dialects import llvm +import cutlass._mlir.dialects.cute as _cute_ir + + +############################################################################## +# Agent class +############################################################################## + + +class Agent(enum.Enum): + """ + Agent indicates what is participating in the pipeline synchronization. + """ + + # Arbitrary grouping of N threads + Thread = enum.auto() + # Same as AsyncThread, but includes all threads in the block + ThreadBlock = enum.auto() + # Same as AsyncThread, but includes all threads in the cluster + ThreadBlockCluster = enum.auto() + + +class CooperativeGroup: + """ + CooperativeGroup contains size and alignment restrictions for an Agent. + """ + + def __init__(self, agent: Agent, size: int = 1, alignment: int = 1): + if agent is Agent.Thread: + assert size > 0 + if size == 32: + assert ( + size == alignment + ), "Error: Alignment does not match number of threads in a warp." + elif size == 128: + assert ( + size == alignment + ), "Error: Alignment does not match number of threads in a warpgroup." + elif agent is Agent.ThreadBlock: + raise NotImplementedError("Error: Not yet supported.") + elif agent is Agent.ThreadBlockCluster: + raise NotImplementedError("Error: Not yet supported.") + else: + # Should never reach this state + size = 0 + + if size <= 0: + raise ValueError( + "Error: The number of threads in a CooperativeGroup must be more than 0." + ) + + # Size indicates how many threads are participating in this CooperativeGroup + self.size = size + # Agent indicates the type of thread group + self.agent = agent + + +class PipelineOp(enum.Enum): + """ + PipelineOp assigns an operation to an agent corresponding to a specific hardware feature. + """ + + # async-threads + AsyncThread = enum.auto() + # Blackwell (SM100a) MMA instruction + TCGen05Mma = enum.auto() + # Tensor Memory Accelerator load + TmaLoad = enum.auto() + # TMA Store consuming smem produced by AsyncThread + TmaStore = enum.auto() + # Composite of multiple PipelineOps + Composite = enum.auto() + + +def _get_pipeline_op(type_str): + return PipelineOp(type_str) + + +############################################################################## +# SyncObject class +############################################################################## + + +class SyncObject(ABC): + """Abstract base class for hardware synchronization primitives. + + This class defines the interface for different types of hardware synchronization + mechanisms including shared memory barriers, named barriers, and fences. + """ + + @abstractmethod + def arrive(self) -> None: + pass + + @abstractmethod + def wait(self) -> None: + pass + + @abstractmethod + def arrive_and_wait(self) -> None: + pass + + @abstractmethod + def arrive_and_drop(self) -> None: + pass + + @abstractmethod + def get_barrier(self) -> Union[cute.Pointer, int, None]: + pass + + @abstractmethod + def max(self) -> Union[int, None]: + pass + + +class MbarrierArray(SyncObject): + """ + MbarrierArray implements an abstraction for an array of smem barriers. + """ + + def __init__( + self, + barrier_storage: cute.Pointer, + num_stages: int, + agent: tuple[PipelineOp, CooperativeGroup], + tx_count: int = 0, + ) -> None: + self.barrier_storage = barrier_storage + self.tx_count = tx_count + self.num_stages = num_stages + self.op_type, self.cg = agent + self.arrive_count = self.cg.size + + if self.num_stages <= 0: + raise ValueError("Error: Mbarrier stage count must be greater than 0.") + if self.arrive_count <= 0: + raise ValueError("Error: Mbarrier arrive count must be greater than 0.") + if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0: + raise ValueError( + "Error: Mbarrier tx count must not be less than 0 for TMA ops." + ) + + # Store mbarrier base pointer + self.mbarrier_base = self.barrier_storage + + # Mbarrier initialization in constructor + self.mbarrier_init() + + def recast_to_new_op_type(self, new_op_type: PipelineOp) -> "MbarrierArray": + """ + Creates a copy of MbarrierArray with a different op_type without re-initializing barriers + """ + # Create new instance without initialization + new_mbarrier_array = object.__new__(MbarrierArray) + + # Copy all attributes directly + new_mbarrier_array.barrier_storage = self.barrier_storage + new_mbarrier_array.op_type = new_op_type + new_mbarrier_array.cg = self.cg + new_mbarrier_array.num_stages = self.num_stages + new_mbarrier_array.tx_count = self.tx_count + new_mbarrier_array.arrive_count = self.arrive_count + new_mbarrier_array.mbarrier_base = self.mbarrier_base + return new_mbarrier_array + + # Mbarrier initialization + def mbarrier_init(self) -> None: + """ + Initializes an array of mbarriers using warp 0. + """ + + def then_body(): + for index in range(self.num_stages): + cute.arch.mbarrier_init(self.get_barrier(index), self.arrive_count) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + if_generate(warp_idx == 0, then_body) + + def arrive( + self, + index: int, + dst: int, + cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, + ) -> None: + """Select the arrive corresponding to this MbarrierArray's PipelineOp. + + :param index: Index of the mbarrier in the array to arrive on + :type index: int + :param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank. + When None, both ``TCGen05Mma`` and ``AsyncThread`` will arrive on their local mbarrier. + - For ``TCGen05Mma``, ``dst`` serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs + in the cluster with rank = 0, 1, and 3). + - For ``AsyncThread``, ``dst`` serves as a destination cta rank (e.g., 3 means threads will arrive on + the mbarrier with rank = 3 in the cluster). + :type dst: int | None + :param cta_group: CTA group for ``TCGen05Mma``, defaults to None for other op types + :type cta_group: ``cute.nvgpu.tcgen05.CtaGroup``, optional + """ + if self.op_type is PipelineOp.AsyncThread: + self.arrive_mbarrier(index, dst) + elif self.op_type is PipelineOp.TCGen05Mma: + assert ( + cta_group is not None + ), "Error: CTA group must be provided for TCGen05Mma." + self.arrive_tcgen05mma(index, dst, cta_group) + elif self.op_type in [PipelineOp.TmaLoad]: + self.arrive_and_expect_tx(index, self.tx_count) + else: + assert ( + False + ), f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}." + + def arrive_mbarrier(self, index: int, dst_rank: Optional[int] = None) -> None: + if dst_rank is None: + cute.arch.mbarrier_arrive(self.get_barrier(index)) + else: + cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank) + + def arrive_tcgen05mma( + self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup + ) -> None: + if mask is None: + with cute.arch.elect_one(): + cute.nvgpu.tcgen05.commit(self.get_barrier(index)) + else: + with cute.arch.elect_one(): + cute.nvgpu.tcgen05.commit(self.get_barrier(index), mask, cta_group) + + def arrive_and_expect_tx(self, index: int, tx_count: int) -> None: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(self.get_barrier(index), tx_count) + + def try_wait(self, index: int, phase: int) -> Boolean: + return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase) + + def wait(self, index: int, phase: int) -> None: + cute.arch.mbarrier_wait(self.get_barrier(index), phase) + + def arrive_and_wait( + self, + index: int, + phase: int, + dst: int, + cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, + ) -> None: + arrive(index, dst, cta_group) + wait(index, phase) + + def arrive_and_drop(self) -> None: + raise NotImplementedError("Error: Not yet supported.") + + def get_barrier(self, index: int) -> cute.Pointer: + return self.mbarrier_base + index + + def max(self) -> int: + # Transaction barriers have a maximum arrive count of 511 (2^9 - 1). + # Non-transaction barriers have a maximum arrive count of 1,048,575 (2^20 - 1). + return 511 + + def __extract_mlir_values__(self): + return [self.barrier_storage] + + def __new_from_mlir_values__(self, values): + return MbarrierArray( + values[0], self.num_stages, (self.op_type, self.cg), self.tx_count + ) + + +@dataclass(frozen=True) +class NamedBarrier(SyncObject): + """ + NamedBarrier is an abstraction for named barriers managed by hardware. + There are 16 named barriers available, with barrier_ids 0-15. + + See the `PTX documentation `__. + """ + + barrier_id: int + num_threads: int + + def __post_init__(self) -> None: + if self.barrier_id < 0 or self.barrier_id >= 16: + raise ValueError("Error: NamedBarrier ID must be between 0 and 16.") + if self.barrier_id == 0: + warnings.warn( + "NamedBarrier ID 0 is by other driver APIs (i.e. sync_threads()) and should not be used." + ) + + def arrive(self) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id, number_of_threads=self.num_threads + ) + + def arrive_unaligned(self) -> None: + """ + The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. + """ + llvm.inline_asm( + None, + [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()], + "barrier.arrive $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + def wait(self) -> None: + """ + NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. + If synchronizing two warps in a producer/consumer pairing, the arrive count would be + 32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer + or consumer are counted for mbarriers, while all threads participating in the sync + are counted for NamedBarriers. + """ + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + self.arrive_and_wait() + + def wait_unaligned(self) -> None: + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + llvm.inline_asm( + None, + [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()], + "barrier.sync $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + def arrive_and_wait(self) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id, number_of_threads=self.num_threads + ) + + def arrive_and_drop(self) -> None: + raise NotImplementedError("Error: Not supported.") + + def sync(self) -> None: + cute.arch.barrier(barrier_id=self.barrier_id) + + def get_barrier(self) -> int: + return self.barrier_id + + def max(self) -> int: + # Transaction barriers have a maximum arrive count of 4095 (2^12 - 1). + return 4095 + + +class TmaStoreFence(SyncObject): + """ + TmaStoreFence is used for a multi-stage epilogue buffer. + """ + + def __init__(self, num_stages: int = 0) -> None: + if num_stages <= 0: + raise ValueError("Mbarrier stage count must be greater than 0.") + + self.num_stages = num_stages + + def arrive(self) -> None: + cute.arch.cp_async_bulk_commit_group() + + def wait(self) -> None: + cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True) + + def arrive_and_wait(self) -> None: + self.arrive() + self.wait() + + def arrive_and_drop(self) -> None: + raise NotImplementedError("Error: Not supported.") + + # TmaStoreFence doesn't have mbarriers + def get_barrier(self) -> None: + assert ( + False + ), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier." + + def max(self) -> None: + raise NotImplementedError("Error: Not supported.") + + def tail(self) -> None: + cute.arch.cp_async_bulk_wait_group(0, read=True) + + +############################################################################## +# PipelineState class +############################################################################## + + +class PipelineUserType(enum.Enum): + Producer = enum.auto() + Consumer = enum.auto() + + +class PipelineState: + """ + Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. + """ + + def __init__(self, stages: int, count, index, phase): + self._stages = stages + self._count = count + self._index = index + self._phase = phase + + def clone(self) -> "PipelineState": + return PipelineState(self.stages, self._count, self.index, self.phase) + + @property + def index(self) -> Int32: + return self._index + + @property + def count(self) -> Int32: + return self._count + + @property + def stages(self) -> int: + return self._stages + + @property + def phase(self) -> Int32: + return self._phase + + def reset_count(self): + self._count = Int32(0) + + def advance(self): + self._index += 1 + self._count += 1 + + def then_body(index, phase): + new_index = Int32(0) + new_phase = phase ^ 1 + return new_index, new_phase + + def else_body(index, phase): + return index, phase + + self._index, self._phase = if_generate( + self._index == self.stages, + then_body, + else_body, + [self.index, self.phase], + [Int32, Int32], + ) + + def reverse(self): + self._index -= 1 + self._count -= 1 + + def then_body(index, phase): + new_index = Int32(self.stages - 1) + new_phase = phase ^ 1 + return new_index, new_phase + + def else_body(index, phase): + return index, phase + + self._index, self._phase = if_generate( + self._index == -1, + then_body, + else_body, + [self.index, self.phase], + [Int32, Int32], + ) + + def __get_mlir_types__(self): + return [self._count.type, self._index.type, self._phase.type] + + def __extract_mlir_values__(self): + count = self._count + index = self._index + phase = self._phase + return [count.ir_value(), index.ir_value(), phase.ir_value()] + + # This can be overridden by derived classes + def __new_from_mlir_values__(self, values): + return PipelineState( + self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2]) + ) + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """ + Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. + """ + if type is PipelineUserType.Producer: + return PipelineState( + stages, + Int32(0), + Int32(0), + Int32(1), + ) + elif type is PipelineUserType.Consumer: + return PipelineState( + stages, + Int32(0), + Int32(0), + Int32(0), + ) + else: + assert ( + False + ), "Error: invalid PipelineUserType specified for make_pipeline_state." + + +############################################################################## +# Helper functions +############################################################################## + + +def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): + """ + Fences the mbarrier init and syncs the threadblock or cluster + """ + cute.arch.mbarrier_init_fence() + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # If not using clusters, sync the threadblock + _sync(Agent.ThreadBlock) + else: + # If using clusters, sync the cluster + _sync(Agent.ThreadBlockCluster) + + +def _sync(group: Agent): + """ + Syncs all threads within an agent. + """ + if group is Agent.Thread: + raise NotImplementedError("Error: Not supported.") + elif group is Agent.ThreadBlock: + cute.arch.sync_threads() + elif group is Agent.ThreadBlockCluster: + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + else: + assert ( + False + ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + + +def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer: + """ + Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment + """ + return cute.make_ptr( + Int64, + val.ir_value(), + mem_space=_cute_ir.AddressSpace.smem, + assumed_align=8, + ) + + +# NamedBarrier free functions +def arrive(barrier_id: int, num_threads: int): + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive(barrier_id=barrier_id, number_of_threads=num_threads) + + +def arrive_unaligned(barrier_id: int, num_threads: int): + """ + The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. + """ + llvm.inline_asm( + None, + [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()], + "barrier.arrive $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def wait(barrier_id: int, num_threads: int): + """ + NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. + If synchronizing two warps in a producer/consumer pairing, the arrive count would be + 32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer + or consumer are counted for mbarriers, while all threads participating in the sync + are counted for NamedBarriers. + """ + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + arrive_and_wait() + + +def wait_unaligned(barrier_id: int, num_threads: int): + warnings.warn( + "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." + ) + llvm.inline_asm( + None, + [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()], + "barrier.sync $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def arrive_and_wait(barrier_id: int, num_threads: int): + cute.arch.barrier(barrier_id=barrier_id, number_of_threads=num_threads) + + +def sync(barrier_id: int = 0): + cute.arch.barrier(barrier_id=barrier_id) diff --git a/python/CuTeDSL/cutlass/pipeline/sm100.py b/python/CuTeDSL/cutlass/pipeline/sm100.py new file mode 100644 index 00000000..591e1d7a --- /dev/null +++ b/python/CuTeDSL/cutlass/pipeline/sm100.py @@ -0,0 +1,452 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, if_generate + +from cutlass.pipeline import ( + CooperativeGroup, + PipelineOp, + PipelineState, + pipeline_init_wait, + PipelineAsync, +) + +############################################################################## +# Pipeline classes +############################################################################## + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineAsync): + """ + PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops). + """ + + is_leader_cta: bool + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask for signaling arrivals to multicasting threadblocks. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2 + ) + tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1 + ) + + block_in_cluster_coord_vmnk_peer = ( + cta_in_cluster_coord_vmnk[0] ^ 1, + *cta_in_cluster_coord_vmnk[1:], + ) + tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 + ) + tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 + ) + + return ( + tma_mcast_mask_a + | tma_mcast_mask_b + | tma_mcast_mask_a_peer + | tma_mcast_mask_b_peer + ) + + @staticmethod + def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): + """ + Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. + """ + bidx, bidy, _ = cute.arch.block_idx() + + mma_coord_vmnk = ( + bidx % cute.size(cta_layout_vmnk, mode=[0]), + bidx // cute.size(cta_layout_vmnk, mode=[0]), + bidy, + None, + ) + return mma_coord_vmnk[0] == 0 + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # No mcast mask if not using clusters + producer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + else: + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) + is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + consumer_mask = producer_mask + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + cta_group, + ) + + def consumer_release(self, state: PipelineState): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + +@dataclass(frozen=True) +class PipelineAsyncUmma(PipelineAsync): + """ + PipelineAsyncUmma is used for AsyncThread producers and UMMA consumers (e.g. Blackwell input fusion pipelines). + """ + + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_leading_cta_rank(cta_v_size): + """ + Computes the leading CTA rank. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + return cta_rank_in_cluster // cta_v_size * cta_v_size + + @staticmethod + def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): + """ + Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. + """ + bidx, bidy, _ = cute.arch.block_idx() + mma_coord_vmnk = ( + bidx % cute.size(cta_layout_vmnk, mode=[0]), + bidx // cute.size(cta_layout_vmnk, mode=[0]), + bidy, + None, + ) + return mma_coord_vmnk[0] == 0 + + @staticmethod + def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask for signaling arrivals to multicasting threadblocks. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0 + ) + block_in_cluster_coord_vmnk_peer = ( + cta_in_cluster_coord_vmnk[0] ^ 1, + *cta_in_cluster_coord_vmnk[1:], + ) + mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0 + ) + return mask_self | mask_peer + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineAsyncUmma. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.AsyncThread + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), + num_stages, + producer, + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + cta_v_size = ( + cute.size(cta_layout_vmnk, mode=[0]) if cta_layout_vmnk is not None else 1 + ) + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: + # No mcast mask if we're not using 2CTA tcgen05 MMA + producer_mask = None + consumer_mask = None + else: + # If we're using 2CTA UMMAs, producer will arrive the mbar on leading CTA + # We need to get the target cta_rank + producer_mask = PipelineAsyncUmma._compute_leading_cta_rank(cta_v_size) + # consumer needs to get the mask to signal + consumer_mask = PipelineAsyncUmma._compute_peer_cta_mask(cta_layout_vmnk) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineAsyncUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + cta_group, + ) + + def consumer_release(self, state: PipelineState): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) + + +@dataclass(frozen=True) +class PipelineUmmaAsync(PipelineAsync): + """ + PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines). + """ + + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask to signal completion of tmem buffers for 2CTA kernels. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + return cute.make_layout_image_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0 + ) + + @staticmethod + def _compute_peer_cta_rank(): + """ + Computes a mask to signal release of tmem buffers for 2CTA kernels. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + return cta_rank_in_cluster // 2 * 2 + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TCGen05Mma + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # Set mask to None if not using clusters (i.e. 1CTA kernels) + producer_mask = None + else: + producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: + # Set mask to None if not using 2CTA intructions + consumer_mask = None + else: + consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank() + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineUmmaAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + cta_group, + ) + + def producer_commit(self, state: PipelineState): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_full.arrive(state.index, self.producer_mask, self.cta_group) + + def producer_tail(self, state: PipelineState): + """ + Make sure the last used buffer empty signal is visible to producer. + Producer tail is usually executed by producer before exit, to avoid dangling + mbarrier arrive signals after kernel exit. + + :param state: The pipeline state that points to next useful buffer + :type state: PipelineState + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + + def then_body(): + # Assume state contains that next useful buffer + # So we only need to advance to num_stages - 1 times to last used buffer + for i in range(self.num_stages - 1): + state.advance() + self.producer_acquire(state) + + if_generate(is_leader_cta, then_body) diff --git a/python/CuTeDSL/cutlass/pipeline/sm90.py b/python/CuTeDSL/cutlass/pipeline/sm90.py new file mode 100644 index 00000000..53e4dc8e --- /dev/null +++ b/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -0,0 +1,803 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, Int32, if_generate + +from cutlass.pipeline import ( + CooperativeGroup, + PipelineOp, + SyncObject, + MbarrierArray, + TmaStoreFence, + PipelineUserType, + PipelineState, + make_pipeline_state, + pipeline_init_wait, +) + +############################################################################## +# Pipeline classes +############################################################################## + + +@dataclass(frozen=True) +class PipelineAsync: + """PipelineAsync is a generic pipeline class where both the producer and consumer are + AsyncThreads. It also serves as a base class for specialized pipeline classes. + + This class implements a producer-consumer pipeline pattern where both sides operate + asynchronously. The pipeline maintains synchronization state using barrier objects + to coordinate between producer and consumer threads. + + The pipeline state transitions of one pipeline entry(mbarrier) can be represented as: + + .. table:: Pipeline State Transitions + :widths: auto + + +-----------+-----------+-----------+-----------+-----------+-----------+ + | Barrier | State | p.acquire | p.commit | c.wait | c.release | + +===========+===========+===========+===========+===========+===========+ + | empty_bar | empty | | n/a | n/a | - | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | empty_bar | wait | | n/a | n/a | -> empty | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | full_bar | wait | n/a | -> full | | n/a | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | full_bar | full | n/a | - | | n/a | + +-----------+-----------+-----------+-----------+-----------+-----------+ + + Where: + + - p: producer + - c: consumer + - : This action is blocked until transition to a state allow it to proceed by other side + - e.g. ``p.acquire()`` is blocked until ``empty_bar`` transition to ``empty`` state by ``c.release()`` + + .. code-block:: text + + Array of mbarriers as circular buffer: + + Advance Direction + <------------------- + + Producer Consumer + | ^ + V | + +-----------------+ + --|X|X|W|D|D|D|D|R|X|<-. + / +-----------------+ \\ + | | + `------------------------' + + Where: + + - X: Empty buffer (initial state) + - W: Producer writing (producer is waiting for buffer to be empty) + - D: Data ready (producer has written data to buffer) + - R: Consumer reading (consumer is consuming data from buffer) + + **Example:** + + .. code-block:: python + + # Create pipeline with 5 stages + pipeline = PipelineAsync.create( + num_stages=5, # number of pipeline stages + producer_group=producer_warp, + consumer_group=consumer_warp + barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory + ) + + # Producer side + producer = pipeline.make_pipeline_producer(producer_warp) + for i in range(num_iterations): + producer.acquire() # Wait for buffer to be empty + # Write data to pipeline buffer + producer.commit() # Signal buffer is full + producer.advance() # Move index to next stage + + # Consumer side + consumer = pipeline.make_pipeline_consumer(consumer_warp) + for i in range(num_iterations): + consumer.wait() # Wait for buffer to be full + # Read data from pipeline buffer + consumer.release() # Signal buffer is empty + consumer.advance() # Move index to next stage + """ + + sync_object_full: SyncObject + sync_object_empty: SyncObject + num_stages: int + producer_mask: Optional[Int32] + consumer_mask: Optional[Int32] + + @staticmethod + def _make_sync_object( + barrier_storage: cute.Pointer, + num_stages: int, + agent: tuple[PipelineOp, CooperativeGroup], + tx_count: int = 0, + ) -> SyncObject: + """ + Returns a SyncObject corresponding to an agent's PipelineOp. + """ + if agent[0] in [ + PipelineOp.AsyncThread, + PipelineOp.TmaLoad, + PipelineOp.TCGen05Mma, + PipelineOp.Composite, + ]: + return MbarrierArray( + barrier_storage=barrier_storage, + num_stages=num_stages, + agent=agent, + tx_count=tx_count, + ) + elif agent[0] is PipelineOp.TmaStore: + # Path taken for AsyncTmaStore + return TmaStoreFence(num_stages=num_stages) + else: + assert False, "Error: Invalid PipelineOp specified." + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + producer_mask: Int32 = None, + consumer_mask: Int32 = None, + ): + """Creates and initializes a new PipelineAsync instance. + + This helper function computes necessary attributes and returns an instance of PipelineAsync + with the specified configuration for producer and consumer synchronization. + + :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: int + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param producer_mask: Mask for signaling arrives for the producer agent, defaults to ``None`` + :type producer_mask: Int32, optional + :param consumer_mask: Mask for signaling arrives for the consumer agent, defaults to ``None`` + :type consumer_mask: Int32, optional + :return: A new PipelineAsync instance + :rtype: PipelineAsync + :raises ValueError: If barrier_storage is not a cute.Pointer instance + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.AsyncThread + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + pipeline_init_wait() + + return PipelineAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + + def producer_try_acquire(self, state: PipelineState): + return self.sync_object_empty.try_wait(state.index, state.phase) + + def producer_commit(self, state: PipelineState): + self.sync_object_full.arrive(state.index, self.producer_mask) + + def consumer_wait( + self, state: PipelineState, try_wait_token: Optional[Boolean] = None + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(state.index, state.phase), + ) + + def consumer_try_wait(self, state: PipelineState): + return self.sync_object_full.try_wait(state.index, state.phase) + + def consumer_release(self, state: PipelineState): + self.sync_object_empty.arrive(state.index, self.consumer_mask) + + def producer_get_barrier(self, state: PipelineState) -> cute.Pointer: + return self.sync_object_full.get_barrier(state.index) + + def producer_tail(self, state: PipelineState): + """ + Make sure the last used buffer empty signal is visible to producer. + Producer tail is usually executed by producer before exit, to avoid dangling + mbarrier arrive signals after kernel exit. + + :param state: The pipeline state that points to next useful buffer + :type state: PipelineState + """ + # Assume state contains that next useful buffer + # So we only need to advance to num_stages - 1 times to last used buffer + for i in range(self.num_stages - 1): + state.advance() + self.producer_acquire(state) + + # Util methods to manage produer and consumer + def make_pipeline_producer(self, group: CooperativeGroup): + state = make_pipeline_state(PipelineUserType.Producer, self.num_stages) + return PipelineProducer(self, state, group) + + def make_pipeline_consumer(self, group: CooperativeGroup): + state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages) + return PipelineConsumer(self, state, group) + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineAsync): + """ + PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops). + """ + + is_signalling_thread: Boolean + + @staticmethod + @cute.jit + def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32): + """ + Initialize the empty barrier arrive signal + This function returns the destination cta rank and a boolean indicating if the signalling thread is the same as the current thread + """ + # Logic to optimally schedule Empty Arrives + cluster_shape_vmnk = cta_layout_vmnk.shape + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + + tidx = tidx % 32 + is_signalling_thread = tidx < cute.size(cluster_shape_vmnk) + dst_rank = tidx % cute.size(cluster_shape_vmnk) + + dst_cta_coord = cta_layout_vmnk.get_hier_coord(dst_rank) + cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster) + + is_same_row = ( + dst_cta_coord[0] == cur_cta_coord[0] + and dst_cta_coord[1] == cur_cta_coord[1] + and dst_cta_coord[3] == cur_cta_coord[3] + ) + is_same_col = ( + dst_cta_coord[0] == cur_cta_coord[0] + and dst_cta_coord[2] == cur_cta_coord[2] + and dst_cta_coord[3] == cur_cta_coord[3] + ) + + is_same_row_or_col = is_same_row or is_same_col + is_signalling_thread_final = is_signalling_thread and is_same_row_or_col + + return dst_rank, is_signalling_thread_final + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + tidx: Optional[Int32] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param tidx: thread index to consumer async threads + :type tidx: Int32 | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + if tidx is None: + tidx, _, _ = cute.arch.thread_idx() + if cta_layout_vmnk is None: + cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + ( + dst_rank, + is_signalling_thread, + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + dst_rank = None + else: + dst_rank = dst_rank + + producer_mask = None + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + dst_rank, + is_signalling_thread, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + self.sync_object_full.arrive(state.index, self.producer_mask) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + def consumer_release(self, state: PipelineState): + """ + TMA consumer release conditionally signals the empty buffer to the producer. + """ + if_generate( + self.is_signalling_thread, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) + + +@dataclass(frozen=True) +class PipelineTmaMultiConsumersAsync(PipelineAsync): + """ + PipelineTmaMultiConsumersAsync is used for TMA producers and UMMA+Async consumers. + """ + + is_leader_cta: bool + sync_object_empty_umma: SyncObject + sync_object_empty_async: SyncObject + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group_umma: CooperativeGroup, + consumer_group_async: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaMultiConsumersAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group_umma: CooperativeGroup for the UMMA consumer agent + :type consumer_group_umma: CooperativeGroup + :param consumer_group_async: CooperativeGroup for the AsyncThread consumer agent + :type consumer_group_async: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.Composite + consumer_type_umma = PipelineOp.TCGen05Mma + consumer_type_async = PipelineOp.AsyncThread + + if consumer_group_umma.agent != consumer_group_async.agent: + raise ValueError( + "UMMA and AsyncThread consumer groups must be the same agent" + ) + + if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1: + raise ValueError( + f"PipelineTmaMultiConsumersAsync is not verified for cta_layout_vmnk != 1, cta_layout_vmnk:{cta_layout_vmnk}" + ) + + consumer_group = CooperativeGroup( + consumer_group_umma.agent, + consumer_group_umma.size + consumer_group_async.size, + ) + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + sync_object_empty_umma = sync_object_empty.recast_to_new_op_type( + consumer_type_umma + ) + sync_object_empty_async = sync_object_empty.recast_to_new_op_type( + consumer_type_async + ) + + # No mcast mask if not using clusters + producer_mask = None + consumer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaMultiConsumersAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + sync_object_empty_umma, + sync_object_empty_async, + cta_group, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + def consumer_release(self, state: PipelineState, op_type: PipelineOp): + if op_type == PipelineOp.TCGen05Mma: + self.sync_object_empty_umma.arrive( + state.index, self.consumer_mask, self.cta_group + ) + elif op_type == PipelineOp.AsyncThread: + self.sync_object_empty_async.arrive(state.index, self.consumer_mask) + else: + raise ValueError(f"Invalid PipelineOp specified. op_type:{op_type}") + + +@dataclass(frozen=True) +class PipelineTmaStore(PipelineAsync): + """ + PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers. + """ + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaStore. + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + """ + producer_type = PipelineOp.TmaStore + + producer = (producer_type, producer_group) + + sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer) + + return PipelineTmaStore(sync_object_full, None, num_stages, None, None) + + def producer_acquire(self): + self.sync_object_full.wait() + + def producer_commit(self): + self.sync_object_full.arrive() + + def consumer_wait(self): + assert False, "Error: PipelineTmaStore does not have a consumer agent." + + def consumer_release(self): + assert False, "Error: PipelineTmaStore does not have a consumer agent." + + def producer_tail(self): + self.sync_object_full.tail() + + +################################################################# +# Utilities to help user of pipeline to simplify the workflow +################################################################# + + +class PipelineProducer: + """A class representing a producer in an asynchronous pipeline. + + The Producer class manages the producer side of an asynchronous pipeline, handling + synchronization and state management for producing data. It provides methods for + acquiring, committing, and advancing through pipeline stages. + + :ivar _pipeline: The asynchronous pipeline this producer belongs to + :type _pipeline: PipelineAsync + :ivar _state: The current state of the producer in the pipeline + :type _state: PipelineState + :ivar _group: The cooperative group this producer operates in + :type _group: CooperativeGroup + + **Examples:** + + .. code-block:: python + + pipeline = PipelineAsync.create(...) + producer = pipeline.create_producer(producer_group, stages) + for i in range(iterations): + producer.acquire() # Wait for buffer to be empty + # Produce data + producer.commit() # Signal data is ready + producer.advance() # Move to next stage + """ + + _pipeline: PipelineAsync + _state: PipelineState + _group: CooperativeGroup + + def __init__(self, pipeline, state, group: CooperativeGroup): + """Initialize a new Producer instance. + + :param pipeline: The pipeline this producer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self._pipeline = pipeline + self._state = state + self._group = group + + @property + def index(self): + """Get the index of the current pipeline stage.""" + return self._state.index + + def get_barrier(self): + """Get the barrier pointer for the current pipeline stage. + + :return: Pointer to the barrier for the current stage + :rtype: cute.Pointer + """ + return self._pipeline.producer_get_barrier(self._state) + + def acquire(self): + """Wait for the current buffer to be empty before producing data. + This is a blocking operation. + """ + self._pipeline.producer_acquire(self._state) + + def try_acquire(self): + """Try to acquire the current buffer without blocking. + + :return: True if acquisition was successful, False otherwise + :rtype: bool + """ + self._pipeline.producer_try_acquire(self._state) + + def commit(self): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + self._pipeline.producer_commit(self._state) + + def tail(self): + """Ensure all used buffers are properly synchronized before producer exit. + This should be called before the producer finishes to avoid dangling signals. + """ + self._pipeline.producer_tail(self._state) + + def advance(self): + """Move to the next pipeline stage.""" + self._state.advance() + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + # TODO: need to handle pipeline as well + return self._state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Producer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Producer instance with state initialized from values + :rtype: Producer + """ + return PipelineProducer( + self._pipeline, self._state.__new_from_mlir_values__(values), self._group + ) + + +class PipelineConsumer: + """A class representing a consumer in an asynchronous pipeline. + + The Consumer class manages the consumer side of an asynchronous pipeline, handling + synchronization and state management for consuming data. It provides methods for + waiting, releasing, and advancing through pipeline stages. + + :ivar _pipeline: The asynchronous pipeline this consumer belongs to + :type _pipeline: PipelineAsync + :ivar _state: The current state of the consumer in the pipeline + :type _state: PipelineState + :ivar _group: The cooperative group this consumer operates in + :type _group: CooperativeGroup + + **Examples:** + .. code-block:: python + + pipeline = PipelineAsync.create(...) + consumer = pipeline.create_consumer(consumer_group, stages) + for i in range(iterations): + consumer.wait() # Wait for data to be ready + # Consume data + consumer.release() # Signal buffer is empty + consumer.advance() # Move to next stage + """ + + _pipeline: PipelineAsync + _state: PipelineState + _group: CooperativeGroup + + def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup): + """Initialize a new Consumer instance. + + :param pipeline: The pipeline this consumer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self._pipeline = pipeline + self._group = group + self._state = state + + @property + def index(self): + """Get the index of the current pipeline stage.""" + return self._state.index + + def wait(self): + """Wait for data to be ready in the current buffer. + This is a blocking operation. + """ + self._pipeline.consumer_wait(self._state) + + def try_wait(self): + """Try to check if data is ready without blocking. + + :return: True if data is ready, False otherwise + :rtype: bool + """ + self._pipeline.consumer_try_wait(self._state) + + def release(self): + """Signal that data consumption is complete for the current stage. + This allows producers to start producing new data. + """ + self._pipeline.consumer_release(self._state) + + def advance(self): + """Move to the next pipeline stage.""" + self._state.advance() + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + return self._state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Consumer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Consumer instance with state initialized from values + :rtype: Consumer + """ + # TODO: need to call pipeline.__new_from_mlir_values__ recursively + return PipelineConsumer( + self._pipeline, self._state.__new_from_mlir_values__(values), self._group + ) diff --git a/python/CuTeDSL/cutlass/torch.py b/python/CuTeDSL/cutlass/torch.py index 0126fb04..32bf0738 100644 --- a/python/CuTeDSL/cutlass/torch.py +++ b/python/CuTeDSL/cutlass/torch.py @@ -29,6 +29,7 @@ from cutlass.cute.typing import ( from cutlass.cute.runtime import from_dlpack import cutlass.cute as cute import torch +from cuda import cuda def dtype(ty: Type[Numeric]): @@ -94,12 +95,13 @@ def create_and_permute_torch_tensor( init_config: Optional[ Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] ] = None, + device: Optional[torch.device] = None, ) -> "torch.Tensor": """ Create a torch tensor with specified shape and dtype. Optionally permute it and initialize it with specified init type and config """ init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32 - init_torch_tensor = torch.empty(*shape, dtype=init_dtype) + init_torch_tensor = torch.empty(*shape, dtype=init_dtype, device=device) if init_type == TensorInitType.SKIP: assert init_config is None f32_torch_tensor = init_torch_tensor @@ -167,3 +169,122 @@ def convert_cute_tensor( # Copy and convert from f32 cute tensor to dtype cute tensor cute.testing.convert(fp32_cute_tensor, cute_tensor) return cute_tensor + + +def default_stream() -> cuda.CUstream: + """ + Get default CUstream from torch stream + """ + torch_stream = torch.cuda.default_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + return stream + + +def current_stream() -> cuda.CUstream: + """ + Get current CUstream from torch stream + """ + torch_stream = torch.cuda.current_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + return stream + + +def matrix( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + cutlass_dtype: Type[Numeric], + init_type: TensorInitType = TensorInitType.RANDOM, + init_config: Optional[ + Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] + ] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """ + Create a torch tensor for matrix + + :param l: length of the matrix + :param mode0: mode0 of the matrix + :param mode1: mode1 of the matrix + :param is_mode0_major: whether the matrix is mode0 major + :param cutlass_dtype: cutlass dtype of the matrix + :param init_type: type of initialization + :param init_config: configuration for initialization + :param device: target torch device + """ + + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + torch_dtype = torch.int8 + else: + torch_dtype = dtype(cutlass_dtype) + + if init_type == TensorInitType.RANDOM and init_config is None: + if torch_dtype.is_signed: + min_val = -2 + max_val = 2 + else: + min_val = 0 + max_val = 4 + init_config = RandomInitConfig(min_val=min_val, max_val=max_val) + + # Create dtype torch tensor + torch_tensor = create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=init_type, + init_config=init_config, + device=device, + ) + + return torch_tensor + + +def cute_tensor_like( + data_ref: torch.Tensor, + cutlass_dtype: Type[Numeric], + is_dynamic_layout: bool, + assumed_align: Optional[int] = None, +) -> tuple[Tensor, torch.Tensor]: + """ + Create a cute tensor use a torch tensor as the data source + + :param data_ref: torch tensor as the data source + :param cutlass_dtype: cutlass dtype of the cute tensor + :param is_dynamic_layout: whether the cute tensor uses dynamic layout + :param assumed_align: assumed alignment of the cute tensor + """ + + # allocate device buffer for cute tensor + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + torch_dtype = torch.int8 + else: + torch_dtype = dtype(cutlass_dtype) + torch_tensor = torch.empty_like(data_ref, dtype=torch_dtype, device="cuda") + + # create cute tensor using the device buffer + cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align) + cute_tensor.element_type = cutlass_dtype + if is_dynamic_layout: + for i, stride in enumerate(torch_tensor.stride()): + if stride == 1: + leading_dim = i + break + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + + # initialize the cute tensor data + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + cute_tensor = convert_cute_tensor( + data_ref.to(dtype=torch.float32), + cute_tensor, + cutlass_dtype, + is_dynamic_layout, + ) + else: + torch_tensor.copy_(data_ref.to(dtype=torch_dtype)) + + return cute_tensor, torch_tensor diff --git a/python/CuTeDSL/cutlass/utils/__init__.py b/python/CuTeDSL/cutlass/utils/__init__.py index dc3fdbcd..30bd2d4c 100644 --- a/python/CuTeDSL/cutlass/utils/__init__.py +++ b/python/CuTeDSL/cutlass/utils/__init__.py @@ -15,20 +15,6 @@ from .static_persistent_tile_scheduler import ( StaticPersistentTileScheduler, ) -from .pipeline import ( - Agent, - CooperativeGroup, - PipelineUserType, - PipelineState, - make_pipeline_state, - PipelineAsync, - PipelineTmaAsync, - PipelineTmaUmma, - PipelineUmmaAsync, - PipelineTmaStore, - pipeline_init_wait, -) - from .hardware_info import ( HardwareInfo, ) @@ -65,6 +51,8 @@ from .smem_allocator import SmemAllocator from .layout import LayoutEnum __all__ = [ + "SmemAllocator", + "LayoutEnum", "WorkTileInfo", "PersistentTileSchedulerParams", "StaticPersistentTileScheduler", diff --git a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py index ca01ad49..167f3efb 100644 --- a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py +++ b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -51,8 +51,13 @@ from cutlass.cute.nvgpu.tcgen05 import ( is_tmem_load, get_tmem_copy_properties, ) +from cutlass.cute.nvgpu.cpasync import ( + CopyBulkTensorTileG2SMulticastOp, + CopyBulkTensorTileG2SOp, +) from cutlass.utils.layout import LayoutEnum + @dsl_user_op def compute_epilogue_tile_shape( cta_tile_shape: cute.Shape, @@ -716,6 +721,7 @@ def make_smem_layout_b( return b_smem_layout_staged + @dsl_user_op def get_smem_layout_atom_epi( layout: LayoutEnum, @@ -827,6 +833,7 @@ SMEM_CAPACITY = { "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value, } + @dsl_user_op def make_trivial_tiled_mma( ab_dtype: Type[Numeric], @@ -908,3 +915,139 @@ def make_trivial_tiled_mma( raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + +@dsl_user_op +def cluster_shape_to_tma_atom_A( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for A based on the number of SMs and the multicast flag. + + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[1], loc=loc, ip=ip) == 1) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) + + +@dsl_user_op +def cluster_shape_to_tma_atom_B( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for Bbased on the number of SMs and the multicast flag. + + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == atom_sm_cnt) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) + + +@dsl_user_op +def cluster_shape_to_tma_atom_SFB( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for SFB based on the number of SMs and the multicast flag. + + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == 1) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) diff --git a/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/python/CuTeDSL/cutlass/utils/hopper_helpers.py index d29daf50..3b94d694 100644 --- a/python/CuTeDSL/cutlass/utils/hopper_helpers.py +++ b/python/CuTeDSL/cutlass/utils/hopper_helpers.py @@ -96,6 +96,10 @@ def make_trivial_tiled_mma( acc_dtype: Type[Numeric], atom_layout_mnk: Tuple[int, int, int], tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, ) -> cute.TiledMma: """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. By default, the MMA atom is created with SMEM operand source for A. @@ -131,7 +135,7 @@ def make_trivial_tiled_mma( a_dtype, acc_dtype, (*tiler_mn, 16), - OperandSource.SMEM, + a_source, a_leading_mode, b_leading_mode, ) @@ -144,7 +148,7 @@ def make_trivial_tiled_mma( b_dtype, acc_dtype, (*tiler_mn, 32), - OperandSource.SMEM, + a_source, a_leading_mode, b_leading_mode, ) diff --git a/python/CuTeDSL/cutlass/utils/pipeline.py b/python/CuTeDSL/cutlass/utils/pipeline.py deleted file mode 100644 index eb104278..00000000 --- a/python/CuTeDSL/cutlass/utils/pipeline.py +++ /dev/null @@ -1,1023 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -import enum -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Optional - -from cutlass.cutlass_dsl import Boolean, Int32, Int64, T, if_generate, and_, or_ - -import cutlass._mlir.dialects.cute as _cute_ir - -import cutlass.cute as cute - - -############################################################################## -# Agent class -############################################################################## - - -class Agent(enum.Enum): - """ - Agent indicates what is participating in the pipeline synchronization. - """ - - # Arbitrary grouping of N threads - Thread = enum.auto() - # Same as AsyncThread, but includes all threads in the block - ThreadBlock = enum.auto() - # Same as AsyncThread, but includes all threads in the cluster - ThreadBlockCluster = enum.auto() - - -class CooperativeGroup: - """ - CooperativeGroup contains size and alignment restrictions for an Agent. - """ - - def __init__(self, agent: Agent, size: int = 1, alignment: int = 1): - if agent is Agent.Thread: - assert size > 0 - if size == 32: - assert ( - size == alignment - ), "Error: Alignment does not match number of threads in a warp." - elif size == 128: - assert ( - size == alignment - ), "Error: Alignment does not match number of threads in a warpgroup." - elif agent is Agent.ThreadBlock: - assert False, "Error: Not yet supported." - elif agent is Agent.ThreadBlockCluster: - assert False, "Error: Not yet supported." - else: - # Should never reach this state - size = 0 - - if size <= 0: - raise ValueError( - "Error: The number of threads in a CooperativeGroup must be more than 0." - ) - - # Size indicates how many threads are participating in this CooperativeGroup - self.size = size - # Agent indicates the type of thread group - self.agent = agent - - -class _PipelineOp(enum.Enum): - """ - PipelineOp assigns an operation to an agent corresponding to a specific hardware feature. - """ - - # async-threads - AsyncThread = enum.auto() - # Blackwell (SM100a) MMA instruction - TCGen05Mma = enum.auto() - # Tensor Memory Accelerator load - TmaLoad = enum.auto() - # TMA Store consuming smem produced by AsyncThread - TmaStore = enum.auto() - - -def _get_pipeline_op(type_str): - return _PipelineOp(type_str) - - -############################################################################## -# SyncObjectArray class -############################################################################## - - -class SyncObjectArray(ABC): - """ - SyncObjectArray is an abstract base class for different types of hardware synchronizations (e.g. smem barriers, named barriers, fences) - """ - - @abstractmethod - def wait(self): - pass - - @abstractmethod - def arrive(self): - pass - - @abstractmethod - def get_barrier(self): - pass - - -class MbarrierArray(SyncObjectArray): - """ - MbarrierArray implements an abstraction for an array of smem barriers. - """ - - def __init__( - self, - barrier_storage: cute.Pointer, - num_stages: int, - agent: tuple[_PipelineOp, CooperativeGroup], - tx_count: int = 0, - ): - self.barrier_storage = barrier_storage - self.tx_count = tx_count - self.num_stages = num_stages - self.op_type, self.cg = agent - self.arrive_count = self.cg.size - - if self.num_stages <= 0: - raise ValueError("Error: Mbarrier stage count must be greater than 0.") - if self.arrive_count <= 0: - raise ValueError("Error: Mbarrier arrive count must be greater than 0.") - if self.op_type is _PipelineOp.TmaLoad and self.tx_count <= 0: - raise ValueError( - "Error: Mbarrier tx count must be greater than 0 for TMA ops." - ) - - # Store mbarrier base pointer - self.mbarrier_base = self.barrier_storage - - # Mbarrier initialization in constructor - self.mbarrier_init() - - # Mbarrier initialization - def mbarrier_init(self): - """ - Initializes an array of mbarriers using warp 0. - """ - - def then_body(): - for index in range(self.num_stages): - cute.arch.mbarrier_init_arrive_cnt( - self.get_barrier(index), self.arrive_count - ) - - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - - if_generate(warp_idx == 0, then_body) - - def arrive( - self, - index: int, - dst: int, - cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, - ): - """ - Select the arrive corresponding to this MbarrierArray's PipelineOp - :param index: Index of the mbarrier in the array to arrive on - :type index: int - :param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank. When None, both TCGen05Mma and AsyncThread will arrive on their local mbarrier. - - For TCGen05Mma, dst serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs in the cluster with rank = 0, 1, and 3). - - For AsyncThread, dst serves as a destination cta rank (e.g., 3 means threads will arrive on the mbarrier with rank = 3 in the cluster). - :type dst: int | None - :param cta_group: CTA group for TCGen05Mma, defaults to None for other op types - :type cta_group: cute.nvgpu.tcgen05.CtaGroup, optional - """ - if self.op_type is _PipelineOp.AsyncThread: - self.arrive_mbarrier(index, dst) - elif self.op_type is _PipelineOp.TCGen05Mma: - assert ( - cta_group is not None - ), "Error: CTA group must be provided for TCGen05Mma." - self.arrive_tcgen05mma(index, dst, cta_group) - elif self.op_type in [_PipelineOp.TmaLoad]: - self.arrive_and_expect_tx(index, self.tx_count) - else: - assert False, f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}." - - def arrive_mbarrier(self, index: int, dst_rank: int): - if dst_rank is None: - cute.arch.mbarrier_arrive(self.get_barrier(index)) - else: - cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank) - - def arrive_tcgen05mma( - self, index: int, mask: int, cta_group: cute.nvgpu.tcgen05.CtaGroup - ): - if mask is None: - with cute.arch.elect_one(): - cute.nvgpu.tcgen05.commit(self.get_barrier(index)) - else: - with cute.arch.elect_one(): - cute.nvgpu.tcgen05.commit( - self.get_barrier(index), - mask, - cta_group, - ) - - def arrive_and_expect_tx(self, index: int, tx_count: int): - with cute.arch.elect_one(): - cute.arch.mbarrier_init_tx_bytes(self.get_barrier(index), tx_count) - - def try_wait(self, index: int, phase: int): - return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase) - - def wait(self, index: int, phase: int): - cute.arch.mbarrier_wait(self.get_barrier(index), phase) - - def get_barrier(self, index: int) -> cute.Pointer: - return self.mbarrier_base + index - - -class TmaStoreFence(SyncObjectArray): - """ - TmaStoreFence is used for a multi-stage epilogue buffer. - """ - - def __init__( - self, - num_stages: int = 0, - ): - if num_stages <= 0: - raise ValueError("Mbarrier stage count must be greater than 0.") - - self.num_stages = num_stages - - def arrive(self): - cute.arch.cp_async_bulk_commit_group() - - def wait(self): - cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True) - - # TmaStoreFence doesn't have mbarriers - def get_barrier(self): - assert ( - False - ), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier." - - def tail(self): - cute.arch.cp_async_bulk_wait_group(0, read=True) - - -############################################################################## -# PipelineState class -############################################################################## - - -class PipelineUserType(enum.Enum): - Producer = enum.auto() - Consumer = enum.auto() - - -class PipelineState: - """ - Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. - """ - - def __init__(self, stages: int, count, index, phase): - self._stages = stages - self._count = count - self._index = index - self._phase = phase - - def clone(self) -> "PipelineState": - return PipelineState(self.stages, self._count, self.index, self.phase) - - @property - def index(self) -> Int32: - return self._index - - @property - def count(self) -> Int32: - return self._count - - @property - def stages(self) -> int: - return self._stages - - @property - def phase(self) -> Int32: - return self._phase - - def reset_count(self): - self._count = Int32(0) - - def advance(self): - self._index += 1 - self._count += 1 - - def then_body(index, phase): - new_index = Int32(0) - new_phase = phase ^ 1 - return new_index, new_phase - - def else_body(index, phase): - return index, phase - - self._index, self._phase = if_generate( - self._index == self.stages, - then_body, - else_body, - [self.index, self.phase], - [Int32, Int32], - ) - - def reverse(self): - self._index -= 1 - self._count -= 1 - - def then_body(index, phase): - new_index = Int32(self.stages - 1) - new_phase = phase ^ 1 - return new_index, new_phase - - def else_body(index, phase): - return index, phase - - self._index, self._phase = if_generate( - self._index == -1, - then_body, - else_body, - [self.index, self.phase], - [Int32, Int32], - ) - - def __get_mlir_types__(self): - return [self._count.type, self._index.type, self._phase.type] - - def __extract_mlir_values__(self): - count = self._count - index = self._index - phase = self._phase - return [count.ir_value(), index.ir_value(), phase.ir_value()] - - # This can be overridden by derived classes - def __new_from_mlir_values__(self, values): - return PipelineState( - self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2]) - ) - - -def make_pipeline_state(type: PipelineUserType, stages: int): - """ - Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. - """ - if type is PipelineUserType.Producer: - return PipelineState( - stages, - Int32(0), - Int32(0), - Int32(1), - ) - elif type is PipelineUserType.Consumer: - return PipelineState( - stages, - Int32(0), - Int32(0), - Int32(0), - ) - else: - assert ( - False - ), "Error: invalid PipelineUserType specified for make_pipeline_state." - - -############################################################################## -# Pipeline classes -############################################################################## - - -@dataclass(frozen=True) -class PipelineAsync: - """ - PipelineAsync is a generic pipeline class where both the producer and consumer are - AsyncThreads. It also serves as a base class for specialized pipeline classes. - """ - - sync_object_array_full: SyncObjectArray - sync_object_array_empty: SyncObjectArray - num_stages: Int32 - producer_mask: Int32 - consumer_mask: Int32 - - @staticmethod - def _make_sync_object_array( - barrier_storage: cute.Pointer, - num_stages: Int32, - agent: tuple[_PipelineOp, CooperativeGroup], - tx_count: int = 0, - ) -> SyncObjectArray: - """ - Returns a SyncObjectArray corresponding to an agent's PipelineOp. - """ - if agent[0] in [ - _PipelineOp.AsyncThread, - _PipelineOp.TmaLoad, - _PipelineOp.TCGen05Mma, - ]: - return MbarrierArray( - barrier_storage=barrier_storage, - num_stages=num_stages, - agent=agent, - tx_count=tx_count, - ) - elif agent[0] is _PipelineOp.TmaStore: - # Path taken for AsyncTmaStore - return TmaStoreFence(num_stages=num_stages) - else: - assert False, "Error: Invalid PipelineOp specified." - - @staticmethod - def create( - barrier_storage: cute.Pointer, - num_stages: Int32, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - producer_mask: Int32 = None, - consumer_mask: Int32 = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent - :type consumer_group: CooperativeGroup - :param producer_mask: Mask for signaling arrives for the producer agent - :type producer_mask: Int32 | None - :param consumer_mask: Mask for signaling arrives for the consumer agent - :type consumer_mask: Int32 | None - """ - producer_type = _PipelineOp.AsyncThread - consumer_type = _PipelineOp.AsyncThread - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_array_full = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8), num_stages, producer - ) - sync_object_array_empty = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - pipeline_init_wait() - - return PipelineAsync( - sync_object_array_full, - sync_object_array_empty, - num_stages, - producer_mask, - consumer_mask, - ) - - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_array_empty.wait(state.index, state.phase), - ) - - def producer_try_acquire(self, state: PipelineState): - return self.sync_object_array_empty.try_wait(state.index, state.phase) - - def producer_commit(self, state: PipelineState): - self.sync_object_array_full.arrive(state.index, self.producer_mask) - - def consumer_wait( - self, state: PipelineState, try_wait_token: Optional[Boolean] = None - ): - if_generate( - try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_array_full.wait(state.index, state.phase), - ) - - def consumer_try_wait(self, state: PipelineState): - return self.sync_object_array_full.try_wait(state.index, state.phase) - - def consumer_release(self, state: PipelineState): - self.sync_object_array_empty.arrive(state.index, self.consumer_mask) - - def producer_get_barrier(self, state: PipelineState) -> cute.Pointer: - return self.sync_object_array_full.get_barrier(state.index) - - def producer_tail(self, state: PipelineState): - """ - Make sure the last used buffer empty signal is visible to producer. - Producer tail is usually executed by producer before exit, to avoid dangling - mbarrier arrive signals after kernel exit. - - :param state: The pipeline state that points to next useful buffer - :type state: PipelineState - """ - # Assume state contains that next useful buffer - # So we only need to advance to num_stages - 1 times to last used buffer - for i in range(self.num_stages - 1): - state.advance() - self.producer_acquire(state) - - -@dataclass(frozen=True) -class PipelineTmaAsync(PipelineAsync): - """ - PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops). - """ - - is_signalling_thread: bool - - @staticmethod - def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout): - """ - Initialize the empty barrier arrive signal - This function returns the destination cta rank and a boolean indicating if the signalling thread is the same as the current thread - """ - # Logic to optimally schedule Empty Arrives - cluster_shape_mnk = cta_layout_vmnk.shape - tidx, _, _ = cute.arch.thread_idx() - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - - is_signalling_thread = tidx < cute.size(cluster_shape_mnk) - dst_rank = tidx % cute.size(cluster_shape_mnk) - m = cluster_shape_mnk[0] - - # Check if same row - is_same_row_l = dst_rank % m - is_same_row_r = cta_rank_in_cluster % m - is_same_row = is_same_row_l == is_same_row_r - - # Check if same column - is_same_col_l = dst_rank // m - is_same_col_r = cta_rank_in_cluster // m - - is_same_col = is_same_col_l == is_same_col_r - - is_same_row_or_col = or_(is_same_row, is_same_col) - is_signalling_thread_final = and_(is_signalling_thread, is_same_row_or_col) - - return dst_rank, is_signalling_thread_final - - @staticmethod - def create( - barrier_storage: cute.Pointer, - num_stages: Int32, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - tx_count: int, - cta_layout_vmnk: Optional[cute.Layout] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent - :type consumer_group: CooperativeGroup - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage - :type tx_count: int - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - """ - producer_type = _PipelineOp.TmaLoad - consumer_type = _PipelineOp.AsyncThread - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_array_full = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8), num_stages, producer, tx_count - ) - sync_object_array_empty = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - dst_rank, is_signalling_thread = ( - PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk) - ) - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - dst_rank = None - else: - dst_rank = dst_rank - - is_signalling_thread = is_signalling_thread - producer_mask = None - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineTmaAsync( - sync_object_array_full, - sync_object_array_empty, - num_stages, - producer_mask, - dst_rank, - is_signalling_thread, - ) - - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): - """ - TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. - """ - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_array_empty.wait(state.index, state.phase), - ) - self.sync_object_array_full.arrive(state.index, self.producer_mask) - - def producer_commit(self, state: PipelineState): - """ - TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA. - """ - pass - - def consumer_release(self, state: PipelineState): - """ - TMA consumer release conditionally signals the empty buffer to the producer. - """ - if_generate( - self.is_signalling_thread, - lambda: self.sync_object_array_empty.arrive( - state.index, self.consumer_mask - ), - ) - - -@dataclass(frozen=True) -class PipelineTmaUmma(PipelineAsync): - """ - PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops). - """ - - is_leader_cta: bool - cta_group: cute.nvgpu.tcgen05.CtaGroup - - @staticmethod - def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout): - """ - Computes a mask for signaling arrivals to multicasting threadblocks. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) - - tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2 - ) - tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1 - ) - - block_in_cluster_coord_vmnk_peer = ( - cta_in_cluster_coord_vmnk[0] ^ 1, - *cta_in_cluster_coord_vmnk[1:], - ) - tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 - ) - tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( - cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 - ) - return ( - tma_mcast_mask_a - | tma_mcast_mask_b - | tma_mcast_mask_a_peer - | tma_mcast_mask_b_peer - ) - - @staticmethod - def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): - """ - Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. - """ - bidx, bidy, _ = cute.arch.block_idx() - - mma_coord_vmnk = ( - bidx % cute.size(cta_layout_vmnk, mode=[0]), - bidx // cute.size(cta_layout_vmnk, mode=[0]), - bidy, - None, - ) - return mma_coord_vmnk[0] == 0 - - @staticmethod - def create( - barrier_storage: cute.Pointer, - num_stages: Int32, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - tx_count: int, - cta_layout_vmnk: Optional[cute.Layout] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent - :type consumer_group: CooperativeGroup - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage - :type tx_count: int - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - """ - producer_type = _PipelineOp.TmaLoad - consumer_type = _PipelineOp.TCGen05Mma - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_array_full = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8), num_stages, producer, tx_count - ) - sync_object_array_empty = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - # No mcast mask if not using clusters - producer_mask = None - # All threadblocks are leaders if not using clusters - is_leader_cta = True - else: - producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) - is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) - - cta_group = ( - cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 - else cute.nvgpu.tcgen05.CtaGroup.TWO - ) - - consumer_mask = producer_mask - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineTmaUmma( - sync_object_array_full, - sync_object_array_empty, - num_stages, - producer_mask, - consumer_mask, - is_leader_cta, - cta_group, - ) - - def consumer_release(self, state: PipelineState): - """ - UMMA consumer release buffer empty, cta_group needs to be provided. - """ - self.sync_object_array_empty.arrive( - state.index, self.consumer_mask, self.cta_group - ) - - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): - """ - TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. - """ - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_array_empty.wait(state.index, state.phase), - ) - if_generate( - self.is_leader_cta, - lambda: self.sync_object_array_full.arrive(state.index, self.producer_mask), - ) - - def producer_commit(self, state: PipelineState): - """ - TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA. - """ - pass - - -@dataclass(frozen=True) -class PipelineUmmaAsync(PipelineAsync): - """ - PipelineTmaUmma is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines). - """ - - cta_group: cute.nvgpu.tcgen05.CtaGroup - - @staticmethod - def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout): - """ - Computes a mask to signal completion of tmem buffers for 2CTA kernels. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) - return cute.make_layout_image_mask( - cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0 - ) - - @staticmethod - def _compute_peer_cta_rank(): - """ - Computes a mask to signal release of tmem buffers for 2CTA kernels. - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - return cta_rank_in_cluster // 2 * 2 - - @staticmethod - def create( - barrier_storage: cute.Pointer, - num_stages: Int32, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - cta_layout_vmnk: Optional[cute.Layout] = None, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent - :type consumer_group: CooperativeGroup - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - """ - producer_type = _PipelineOp.TCGen05Mma - consumer_type = _PipelineOp.AsyncThread - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_array_full = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8), num_stages, producer - ) - sync_object_array_empty = PipelineAsync._make_sync_object_array( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - # Set mask to None if not using clusters (i.e. 1CTA kernels) - producer_mask = None - else: - producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk) - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: - # Set mask to None if not using 2CTA intructions - consumer_mask = None - else: - consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank() - - cta_group = ( - cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 - else cute.nvgpu.tcgen05.CtaGroup.TWO - ) - - pipeline_init_wait(cta_layout_vmnk) - - return PipelineUmmaAsync( - sync_object_array_full, - sync_object_array_empty, - num_stages, - producer_mask, - consumer_mask, - cta_group, - ) - - def producer_commit(self, state: PipelineState): - """ - UMMA producer commit buffer full, cta_group needs to be provided. - """ - self.sync_object_array_full.arrive( - state.index, self.producer_mask, self.cta_group - ) - - def producer_tail(self, state: PipelineState): - """ - Make sure the last used buffer empty signal is visible to producer. - Producer tail is usually executed by producer before exit, to avoid dangling - mbarrier arrive signals after kernel exit. - - :param state: The pipeline state that points to next useful buffer - :type state: PipelineState - """ - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 - - def then_body(): - # Assume state contains that next useful buffer - # So we only need to advance to num_stages - 1 times to last used buffer - for i in range(self.num_stages - 1): - state.advance() - self.producer_acquire(state) - - if_generate(is_leader_cta, then_body) - - -@dataclass(frozen=True) -class PipelineTmaStore(PipelineAsync): - """ - PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers. - """ - - @staticmethod - def create( - num_stages: Int32, - producer_group: CooperativeGroup, - ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaStore. - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent - :type producer_group: CooperativeGroup - """ - producer_type = _PipelineOp.TmaStore - - producer = (producer_type, producer_group) - - sync_object_array_full = PipelineAsync._make_sync_object_array( - None, num_stages, producer - ) - - return PipelineTmaStore(sync_object_array_full, None, num_stages, None, None) - - def producer_acquire(self): - self.sync_object_array_full.wait() - - def producer_commit(self): - self.sync_object_array_full.arrive() - - def consumer_wait(self): - assert False, "Error: PipelineTmaStore does not have a consumer agent." - - def consumer_release(self): - assert False, "Error: PipelineTmaStore does not have a consumer agent." - - def producer_tail(self): - self.sync_object_array_full.tail() - - -############################################################################## -# Helper functions -############################################################################## - - -def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): - """ - Fences the mbarrier init and syncs the threadblock or cluster - """ - cute.arch.mbarrier_init_fence() - - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - # If not using clusters, sync the threadblock - _sync(Agent.ThreadBlock) - else: - # If using clusters, sync the cluster - _sync(Agent.ThreadBlockCluster) - - -def _sync(group: Agent): - """ - Syncs all threads within an agent. - """ - if group is Agent.Thread: - assert False, "Error: Not supported." - elif group is Agent.ThreadBlock: - cute.arch.sync_threads() - elif group is Agent.ThreadBlockCluster: - cute.arch.cluster_arrive() - cute.arch.cluster_wait() - else: - assert ( - False - ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." - - -def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer: - """ - Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment - """ - return cute.make_ptr( - Int64, - val.ir_value(), - mem_space=_cute_ir.AddressSpace.smem, - assumed_align=8, - ) diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index a2d72698..ded0ae43 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -18,41 +18,25 @@ from cutlass.cute.arch import get_dyn_smem class SmemAllocator: - """ - A class for managing shared memory allocation on GPU. + """A class for managing shared memory allocation on GPU. - This class manages a chunk of shared memory and provide APIs for sub-allocation + This class manages a chunk of shared memory and provides APIs for sub-allocation inside the chunk. - Attributes - ---------- - _base : cute.Pointer as i8 typed dynamic value - The current base address of the shared memory. + :ivar _base: The current base address of the shared memory as an i8 typed dynamic value. + :type _base: cute.Pointer + :ivar _allocated_bytes: The total number of bytes allocated in shared memory. + :type _allocated_bytes: int - _allocated_bytes: - The bytes allocated in shared memory. - - Methods - ------- - allocate(num_bytes, alignment) - Allocates num_bytes in the shared memory with the given byte alignment. - - allocate_value(value_ty, num_elems) - Allocates num_elems of value_ty values in the shared memory. - - allocate_tensor(value_ty, layout, alignment) - Allocates a tensor in the shared memory with given layout and byte alignment. - - Notes - ----- - This class is responsible for managing the allocation of tensors in shared memory. + .. note:: + This class is responsible for managing the allocation of tensors in shared memory. + The base pointer is aligned to 1024 bytes upon initialization. """ def __init__(self): - """ - Initializes the SmemAllocator instance with dynamic smem base ptr, - which is i8 type and aligned to 1024. + """Initialize the SmemAllocator instance. + Creates a dynamic shared memory base pointer of type i8, aligned to 1024 bytes. """ self._base = get_dyn_smem(Int8, alignment=1024) self._allocated_bytes = 0 @@ -64,30 +48,19 @@ class SmemAllocator: def allocate(self, size_or_type: cute.struct, byte_alignment: int): ... def allocate(self, size_or_type, byte_alignment: int = 1) -> int: + """Allocate a block of memory with specified size and alignment. + + This method adjusts the base pointer to ensure proper alignment and updates + the internal state to track allocated memory. + + :param size_or_type: The number of bytes to allocate or a struct class + :type size_or_type: Union[int, cute.struct] + :param byte_alignment: The byte alignment requirement, defaults to 1 (no alignment) + :type byte_alignment: int, optional + :return: Pointer to the start of the allocated memory block or struct instance + :rtype: cute.Pointer + :raises ValueError: If size is negative or alignment is less than 1 """ - Allocates a block of memory with the specified size and byte alignment. - - This method adjusts the base cute.Pointer to ensure that the allocated memory - is aligned according to the specified byte alignment. It updates the internal - state to reflect the new base cute.Pointer and the total allocated bytes. - - Parameters - ---------- - size_or_type : int or struct - The number of bytes to allocate or struct class. - byte_alignment : int - The byte alignment requirement for the allocation. Defaults to 1 (no alignment). - - Returns - ---------- - A cute.Pointer to the start of the allocated memory block or struct instance. - - Raises - ---------- - ValueError - If num_bytes is negative or if byte_alignmemt is less than 1. - """ - if isinstance(size_or_type, cute.struct): alignment = max(byte_alignment, size_or_type.__alignof__()) base_ptr = self.allocate(size_or_type.__sizeof__(), alignment) @@ -110,27 +83,16 @@ class SmemAllocator: return ptr def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1): - """ - Allocates num_elems values of element_type in shared memory. + """Allocate an array of elements in shared memory. - This method calls allocate() to return a byte ptr, pointing to start of shared - memory. Then calls cute.recast_ptr() to recast this byte cute.Pointer to element_type. - - Parameters - ---------- - element_type : Type[Numeric] - The type of the values in the tensor. - num_elems : int, optional - The number of elements for each allocation. Defaults to 1. - - Returns - ---------- - A value_type cute.Pointer to the start of the allocated memory block. - - Raises - ---------- - ValueError - If num_elems is less than 1. + :param element_type: The type of elements to allocate + :type element_type: Type[Numeric] + :param num_elems: Number of elements to allocate, defaults to 1 + :type num_elems: int, optional + :return: Pointer to the start of the allocated array + :rtype: cute.Pointer + :raises ValueError: If num_elems is less than 1 + :raises TypeError: If element_type is not a Numeric type """ if num_elems < 1: raise ValueError("num_elems must be at least 1") @@ -152,28 +114,21 @@ class SmemAllocator: byte_alignment: int = 1, swizzle: cute.Swizzle = None, ): - """ - Allocates a tensor in the shared memory with value type, layout and byte alignment. + """Allocate a tensor in shared memory. - Parameters - ---------- - element_type : Type[Numeric] - The type of the values in the tensor. - layout : int | DynamicInt | cute.Layout | cute.ComposedLayout - The layout of the tensor. - byte_alignment : int, optional - The byte alignment requirement for the allocation. Defaults to 1 (no alignment). - swizzle : cute.Swizzle - A swizzle for the iterator (for position-dependent swizzling). - - Returns - ------- - tensor : cute.Tensor - The allocated tensor with specified value type, layout and byte alignment. - - Notes - ----- - The base address is updated to point to the next available memory location. + :param element_type: The type of elements in the tensor + :type element_type: Type[Numeric] + :param layout: The layout specification for the tensor + :type layout: Union[int, cute.Layout, cute.ComposedLayout] + :param byte_alignment: The byte alignment requirement, defaults to 1 + :type byte_alignment: int, optional + :param swizzle: Swizzle for position-dependent swizzling, defaults to None + :type swizzle: cute.Swizzle, optional + :return: The allocated tensor with specified properties + :rtype: cute.Tensor + :raises TypeError: If element_type is not a Numeric type or if swizzle conflicts with layout + :raises ValueError: If allocation is not byte-aligned + :raises NotImplementedError: If dynamic layout is specified """ if not isinstance(element_type, NumericMeta): raise TypeError( diff --git a/python/CuTeDSL/cutlass_dsl/__init__.py b/python/CuTeDSL/cutlass_dsl/__init__.py index 9c6861c3..4d4b4ad2 100644 --- a/python/CuTeDSL/cutlass_dsl/__init__.py +++ b/python/CuTeDSL/cutlass_dsl/__init__.py @@ -23,6 +23,11 @@ from ..base_dsl.ast_helpers import ( dynamic_expr, assert_executor, bool_cast, + compare_executor, + any_executor, + all_executor, + range_value_check, + range_perf_warning, ) from ..base_dsl import * diff --git a/python/CuTeDSL/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass_dsl/cutlass.py index 3cbd6874..9a61a746 100644 --- a/python/CuTeDSL/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass_dsl/cutlass.py @@ -20,6 +20,7 @@ from inspect import isclass import functools import pkgutil from dataclasses import is_dataclass +from collections.abc import Sequence from ..base_dsl import * from ..base_dsl import compiler @@ -51,6 +52,11 @@ from ..base_dsl.ast_helpers import ( while_executor, assert_executor, bool_cast, + compare_executor, + any_executor, + all_executor, + range_value_check, + range_perf_warning, ) from ..base_dsl.runtime.dlpack_runtime import ( get_cute_tensor_c_pointer, @@ -67,18 +73,6 @@ from .cutlass_ast_decorators import ( _while_execute_dynamic, ) -# ============================================================================= -# Set the AST decorator -# ============================================================================= - -# Set the DSL specific functions -executor.set_functions( - is_dynamic_expression, - _loop_execute_range_dynamic, - _if_execute_dynamic, - _while_execute_dynamic, -) - # ============================================================================= # Cutlass DSL Base Abstract Class @@ -1023,7 +1017,6 @@ def select_(cond, if_value, else_value): ) return value - # Non-DSL dynamic cond should be handled before this. if const_expr(not is_dynamic_expression(cond)): raise DSLRuntimeError("Conditional expression must be dynamic") @@ -1089,6 +1082,7 @@ def for_generate( iter_args: Optional[Sequence[ir.Value]] = None, *, unroll: LoopUnroll = None, + pipelining=None, loc=None, ip=None, ): @@ -1126,6 +1120,9 @@ def for_generate( if unroll is not None: for_op.attributes["loop_annotation"] = unroll + if pipelining is not None: + for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining) + iv = for_op.induction_variable new_results = new_from_mlir_values(iter_args, for_op.results) new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args) @@ -1319,3 +1316,122 @@ def while_generate( Generate a WhileLoopContext for a dynamic loop. """ return WhileLoopContext(inputs, condition, loc=loc, ip=ip) + + +def equal(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs == rhs + + # Both sequence + if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): + # Short-circuit for unequal length + if len(lhs) != len(rhs): + return False + return all_(equal(l, r) for l, r in zip(lhs, rhs)) + return lhs == rhs + + +def in_(lhs, rhs, op): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs in rhs + + if not isinstance(rhs, Sequence): + raise DSLRuntimeError( + f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}" + ) + + return any_(equal(lhs, r) for r in rhs) + + +def _lt_gt(lhs, rhs, op): + def native_lt_gt(lhs, rhs, op): + if op == "<": + return lhs < rhs + elif op == ">": + return lhs > rhs + else: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return native_lt_gt(lhs, rhs, op) + + # Both sequence, comparisons other than == and != do not allow mixing different types of sequences + if ( + isinstance(lhs, Sequence) + and isinstance(rhs, Sequence) + and type(lhs) == type(rhs) + ): + unequal_found = False + comp_results = [] + mask = [] + for l, r in zip(lhs, rhs): + is_equal = equal(l, r) + mask.append(not_(or_(is_equal, unequal_found))) + unequal_found = not_(is_equal) + comp_results.append(_lt_gt(l, r, op)) + + result = any_(and_(r, m) for r, m in zip(comp_results, mask)) + + if len(lhs) != len(rhs): + # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types + # If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one + has_valid_mask = any_(mask) + if op == "<": + length_result = len(lhs) < len(rhs) + elif op == ">": + length_result = len(lhs) > len(rhs) + if type(has_valid_mask) == bool: + return result if has_valid_mask else length_result + else: + return select_(has_valid_mask, result, length_result) + else: + return result + else: + return native_lt_gt(lhs, rhs, op) + + +def greater_than(lhs, rhs): + return _lt_gt(lhs, rhs, ">") + + +def less_than(lhs, rhs): + return _lt_gt(lhs, rhs, "<") + + +def _compare_executor(left, comparators, ops): + result = left + for comparator, op in zip(comparators, ops): + # 'is' and 'is not' are pure python operators + if op == "is": + result = result is comparator + elif op == "is not": + result = result is not comparator + elif op in ["in", "not in"]: + result = in_(left, comparator, op) + elif op in ["==", "!="]: + result = equal(left, comparator) + elif op in ["<", ">="]: + result = less_than(left, comparator) + elif op in [">", "<="]: + result = greater_than(left, comparator) + else: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + # Invert the result for NotIn, NotEq, GtE, LtE + if op in ["not in", "!=", ">=", "<="]: + result = not_(result) + return result + +# ============================================================================= +# Set the AST decorator +# ============================================================================= + +# Set the DSL specific functions +executor.set_functions( + is_dynamic_expression, + _loop_execute_range_dynamic, + _if_execute_dynamic, + _while_execute_dynamic, + _compare_executor, + any_, + all_, +) diff --git a/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py b/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py index 56262a76..370a0c9f 100644 --- a/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py +++ b/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py @@ -10,6 +10,7 @@ # is strictly prohibited. from typing import List, Tuple +from types import NoneType from cutlass._mlir import ir from cutlass._mlir.dialects import scf, arith from cutlass._mlir.extras import types as T @@ -145,6 +146,30 @@ class ScfGenerator: # Use the provided terminator generator block_term_op_builder[builder](region_result) else: + # For standard yield op, check result + for arg, result, name in zip( + mix_iter_args, + ( + region_result + if isinstance(region_result, list) + else [region_result] + ), + mix_iter_arg_names, + ): + if isinstance(arg, NoneType) and not isinstance( + result, NoneType + ): + raise DSLRuntimeError( + ( + f"`{name}` is None prior to this `{op_type_name}`, " + f"and update to non-None value inside of this `{op_type_name}` is not supported." + ), + suggestion=( + f"Please make sure `{name}` is not None prior to this `{op_type_name}`, " + f"or mark this `{op_type_name}` with " + f"`{'range' if op_type_name == 'for' else 'const_expr'}`." + ), + ) # Normalize region_result region_result_list = ScfGenerator._normalize_region_result_to_list( region_result @@ -200,6 +225,7 @@ def _loop_execute_range_dynamic( mix_iter_arg_names: List[str] = [], unroll: int = -1, unroll_full: bool = False, + pipelining: int = None, ): """ Example: build an scf.for with optional unroll, using our universal helper. @@ -236,6 +262,18 @@ def _loop_execute_range_dynamic( unroll_attr = LoopUnroll(count=unroll) log().debug("Unroll attribute: %s", unroll_attr) + pipelining_attr = None + if pipelining is not None: + if pipelining >= 0: + pipelining_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), pipelining + ) + else: + raise DSLRuntimeError( + f"Pipelining must be non-negative, got {pipelining}" + ) + log().debug("Pipelining attribute: %s", pipelining_attr) + log().debug( "Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s", start_, @@ -265,6 +303,9 @@ def _loop_execute_range_dynamic( if unroll_attr is not None: for_op.attributes["loop_annotation"] = unroll_attr + if pipelining_attr is not None: + for_op.attributes["cutlass.pipelining"] = pipelining_attr + return for_op def for_body_builder( diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index 814a296c..aec440f9 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.0.0 +nvidia-cutlass-dsl==4.1.0.dev0 diff --git a/python/cutlass/backend/evt/frontend/frontend_base.py b/python/cutlass/backend/evt/frontend/frontend_base.py index 442a708d..06d3477b 100644 --- a/python/cutlass/backend/evt/frontend/frontend_base.py +++ b/python/cutlass/backend/evt/frontend/frontend_base.py @@ -73,6 +73,7 @@ class EVTFrontendBase: self.dag_ir = DAGIR(self.cc, self.element_compute) self.compute_cnt = 0 self.layout_cnt = 0 + self.imm_cnt = 0 self.pass_manager = EVTPassManager( self.dag_ir, @@ -107,6 +108,13 @@ class EVTFrontendBase: # Parse the input self.parse(*args, **kwargs) + # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0 + if (self.cc >= 90): + if (self.dag_ir.out_degree("D") != 0): + raise RuntimeError( + f"On SM90 or higher, D is expected to be a output node with 0 users to " + f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}") + # Run the passes self.pass_manager() # Set the epilogue type @@ -187,7 +195,8 @@ class EVTFrontendBase: except: raise ValueError(f"{type(value).__name__} cannot be converted to float.") - name = f"imm_{value}".replace('.', '_') + name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_') + self.imm_cnt += 1 load_node = LoadNode(name) load_node.tensor = {"tensor": value, "is_constant": True} self.add_node(load_node) diff --git a/python/cutlass/backend/evt/frontend/python_ast.py b/python/cutlass/backend/evt/frontend/python_ast.py index 14827812..258e9421 100644 --- a/python/cutlass/backend/evt/frontend/python_ast.py +++ b/python/cutlass/backend/evt/frontend/python_ast.py @@ -42,7 +42,7 @@ from cutlass_library import DataType import cutlass from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase -from cutlass.backend.epilogue import relu +from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu from cutlass.backend.library import FunctionalOp @@ -72,10 +72,17 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor): ast.Div: FunctionalOp.Divides, "maximum": FunctionalOp.Maximum, "minimum": FunctionalOp.Minimum, + "identity": identity.binding_type, "relu": relu.binding_type, + "tanh": tanh.binding_type, + "sigmoid": sigmoid.binding_type, + "silu": silu.binding_type, + "hardswish": hardswish.binding_type, + "gelu": gelu.binding_type, "multiply_add": FunctionalOp.MultiplyAdd, "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd), - "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum) + "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum), + "exp": FunctionalOp.Exp } return mapping[op] diff --git a/python/cutlass/backend/evt/ir/dag_ir.py b/python/cutlass/backend/evt/ir/dag_ir.py index 16281d34..bb9121a7 100644 --- a/python/cutlass/backend/evt/ir/dag_ir.py +++ b/python/cutlass/backend/evt/ir/dag_ir.py @@ -38,7 +38,9 @@ import networkx as nx from cutlass_library import DataType +from cutlass.backend.evt.ir.compute_nodes import ComputeNode from cutlass.backend.evt.ir.node import NodeBase +from cutlass.backend.library import ActivationOp from cutlass.backend.utils import device_cc @@ -59,6 +61,8 @@ class DAGIR: self.cc = cc + self.identity_counter = 0 + # # IR manipulator # @@ -79,7 +83,21 @@ class DAGIR: raise SyntaxError(f"Variable '{src}' is undefined.") if not self.has_node(dst): raise SyntaxError(f"Variable '{dst}' is undefined.") - self._graph.add_edge(src, dst, weight=weight) + + if self._graph.has_edge(src, dst): + # The DiGraph doesn't support multiple edges between two nodes + # We insert an identity node in such case as a workaround + identity_name = f"autogen_identity_{self.identity_counter}" + self.identity_counter += 1 + compute_node = ComputeNode( + name=identity_name, fn=ActivationOp.Identity, + element_output=self.element_compute, + element_compute=self.element_compute) + self.add_node(compute_node) + self.add_edge(src, identity_name, 0) + self.add_edge(identity_name, dst, weight) + else: + self._graph.add_edge(src, dst, weight=weight) def remove_node(self, node: str): """ diff --git a/python/cutlass/backend/evt/ir/tensor.py b/python/cutlass/backend/evt/ir/tensor.py index b8d1bbe0..9eea9f42 100644 --- a/python/cutlass/backend/evt/ir/tensor.py +++ b/python/cutlass/backend/evt/ir/tensor.py @@ -51,15 +51,19 @@ class Tensor: """ The tensor abstracts the data type """ - def __init__(self, tensor=None, element=None, shape=None, layout_tag=None, is_constant=False) -> None: + def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None: if element is not None and tensor is not None: raise Exception(f"Must not specify both element and tensor") elif shape is not None and tensor is not None: raise Exception(f"Must not specify both shape and tensor") elif layout_tag is not None and tensor is not None: raise Exception(f"Must not specify both layout_tag and tensor") - elif (element is None or layout_tag is None or shape is None) and (tensor is None) : - raise Exception(f"Must specify one of (element, shape, layout) or (tensor)") + elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) : + raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)") + elif stride is not None and tensor is not None: + raise Exception(f"Must not specify both stride and tensor") + elif stride is not None and layout_tag is not None: + raise Exception(f"Must not specify layout_tag when stride is provided") if isinstance(tensor, Tensor): # Directly copy all the attributes @@ -70,10 +74,13 @@ class Tensor: else: self.element, layout_tag = get_datatype_and_layout(tensor) shape = get_tensor_shape(tensor) - if layout_tag == LayoutType.RowMajor: - self.layout = Layout(shape[::-1]) - elif layout_tag == LayoutType.ColumnMajor: - self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))]) + if stride is not None: + self.layout = Layout(shape[::-1], stride[::-1]) + else: + if layout_tag == LayoutType.RowMajor: + self.layout = Layout(shape[::-1]) + elif layout_tag == LayoutType.ColumnMajor: + self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))]) self.layout = canonicalization(self.layout) self.is_constant = is_constant diff --git a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py index 91eb2054..9fad1de3 100644 --- a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py +++ b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py @@ -77,11 +77,12 @@ class PassDAG2Tree(EVTPassBase): reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent))) # get the common reachable objects common_items = set.intersection(*reachable_nodes) + node_to_fuse = set.union(*reachable_nodes).difference(common_items) + lca = None # If common ancestor exists, find the lowest one if len(common_items) > 0: topo_order = self.dag_ir.nodes_topological_order() - lca = None topo_idx = -1 for item in common_items: if lca is None: @@ -91,53 +92,74 @@ class PassDAG2Tree(EVTPassBase): if topo_idx > topo_order.index(item): lca = item topo_idx = topo_order.index(item) - # The lca is the output node of the DAG node - # Get the nodes to be fused - node_to_fuse = set.union(*reachable_nodes).difference(common_items) - node_to_fuse.add(lca) - # Get all the input nodes - all_input_nodes = [] - all_output_nodes = [] - for node in node_to_fuse: - all_input_nodes.append(set(self.dag_ir.get_all_inputs(node))) - all_output_nodes.append(set(self.dag_ir.get_users(node))) - all_input_nodes = set.union(*all_input_nodes) - all_output_nodes = set.union(*all_output_nodes) - - new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes) - - # Create the subgraph - subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes) - subgraph = DAGIR(self.dag_ir.cc) - for node in subgraph_.nodes: - meta = deepcopy(self.dag_ir.get_node_meta(node)) - if node not in node_to_fuse: - meta.disabled = True - subgraph.add_node(meta) - for edge in subgraph_.edges: - subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1])) - - - # Create the fused node - dag_node = TopoVisitorNode( - name=f"dag_{lca}", subgraph=subgraph, - output_node=self.dag_ir.get_node_meta(lca)) - self.dag_ir.add_node(dag_node) - - # Add input edges - for idx, node in enumerate(all_input_nodes): - self.dag_ir.add_edge(node, dag_node.name, weight=idx) - - # Replace all uses with DAG node (only 1 output node) - self.dag_ir.replace_all_uses_with(lca, dag_node.name) - - # Remove all fused nodes - node_to_fuse.remove(lca) - for node in node_to_fuse: - self.dag_ir.remove_node(node) - else: - raise NotImplementedError("No LCA found. Consider SplitTreeVisitor.") + # there is no common ancestor for all the parents, we pack all the reachable + # nodes into a single DAG node as a fallback. The lca should be the input node of + # one of the output nodes with out_degree = 0 + potential_output_nodes = [] + for node in node_to_fuse: + if self.dag_ir.out_degree(node) == 0: + potential_output_nodes.append(node) + if len(potential_output_nodes) == 0: + raise RuntimeError(f"No output node with out degree = 0 found.") + + output_node = None + if (self.dag_ir.cc >= 90): + # For SM90, the lca should be the input node of D + if (not self.dag_ir.has_node("D")): + raise RuntimeError(f"D is not a node in the DAG IR.") + output_node = "D" + else: + output_node = potential_output_nodes[0] + + if (output_node is None): + raise RuntimeError(f"No output node found.") + lca = self.dag_ir.get_all_inputs(output_node)[0] + node_to_fuse.remove(output_node) + + # The lca is the output node of the DAG node + # Get the nodes to be fused + node_to_fuse.add(lca) + # Get all the input nodes + all_input_nodes = [] + all_output_nodes = [] + for node in node_to_fuse: + all_input_nodes.append(set(self.dag_ir.get_all_inputs(node))) + all_output_nodes.append(set(self.dag_ir.get_users(node))) + all_input_nodes = set.union(*all_input_nodes) + all_output_nodes = set.union(*all_output_nodes) + + new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes) + + # Create the subgraph + subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes) + subgraph = DAGIR(self.dag_ir.cc) + for node in subgraph_.nodes: + meta = deepcopy(self.dag_ir.get_node_meta(node)) + if node not in node_to_fuse: + meta.disabled = True + subgraph.add_node(meta) + for edge in subgraph_.edges: + subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1])) + + + # Create the fused node + dag_node = TopoVisitorNode( + name=f"dag_{lca}", subgraph=subgraph, + output_node=self.dag_ir.get_node_meta(lca)) + self.dag_ir.add_node(dag_node) + + # Add input edges + for idx, node in enumerate(all_input_nodes): + self.dag_ir.add_edge(node, dag_node.name, weight=idx) + + # Replace all uses with DAG node (only 1 output node) + self.dag_ir.replace_all_uses_with(lca, dag_node.name) + + # Remove all fused nodes + node_to_fuse.remove(lca) + for node in node_to_fuse: + self.dag_ir.remove_node(node) def ensures(self) -> None: # Ensure that after the pass, the resulting DAG becomes a tree diff --git a/python/cutlass/backend/library.py b/python/cutlass/backend/library.py index 4e0812c4..51fe1fe9 100644 --- a/python/cutlass/backend/library.py +++ b/python/cutlass/backend/library.py @@ -118,6 +118,7 @@ class FunctionalOp(enum.Enum): Multiplies = enum_auto() MultiplyAdd = enum_auto() Plus = enum_auto() + Exp = enum_auto() FunctionalOpTag = { @@ -130,6 +131,7 @@ FunctionalOpTag = { FunctionalOp.Multiplies: "cutlass::multiplies", FunctionalOp.MultiplyAdd: "cutlass::multiply_add", FunctionalOp.Plus: "cutlass::plus", + FunctionalOp.Exp: "cutlass::fast_exp_op", } diff --git a/python/cutlass/epilogue/__init__.py b/python/cutlass/epilogue/__init__.py index 3646d9b1..43a0beb6 100644 --- a/python/cutlass/epilogue/__init__.py +++ b/python/cutlass/epilogue/__init__.py @@ -52,4 +52,5 @@ from cutlass.epilogue.evt_ops import ( reshape, maximum, minimum, + exp ) diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass/epilogue/evt_ops.py index aa4ec292..0d2ef36e 100644 --- a/python/cutlass/epilogue/evt_ops.py +++ b/python/cutlass/epilogue/evt_ops.py @@ -73,6 +73,12 @@ def minimum(x, y): elif is_torch_tensor(x): return torch.minimum(x, torch.tensor(y)) +def exp(x): + if is_numpy_tensor(x): + return np.exp(x) + elif is_torch_tensor(x): + return torch.exp(x) + ############################################################################## # Layout manipulate nodes diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 243f5adb..0e6d8883 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -297,7 +297,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode sm100_mma_data_type_general = [ 'gemm_f16_f16_f16_f16_f16', 'gemm_f16_f16_f16_void_f16', - 'gemm_f16_f16_f32_f16_f16', + #'gemm_f16_f16_f32_f16_f16', 'tf32gemm_f32_f32_f32_f32_f32', 'bf16gemm_f32_f32_f32_f32_f32', ] @@ -336,7 +336,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', 'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2', 'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2', - 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + #'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', ] @@ -547,7 +547,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode if dynamic_cluster: if mode == "functional_L0": - runtime_cluster_shapes = [[1,1,1], [2,1,1], [2,2,1], [4,1,1], [4,4,1]] + runtime_cluster_shapes = [[1,1,1], [2,2,1]] else: runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]] cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape diff --git a/test/python/cutlass/evt/evt_compute_sm80_90.py b/test/python/cutlass/evt/evt_compute_sm80_90.py index dd58eb0a..43a9c02d 100644 --- a/test/python/cutlass/evt/evt_compute_sm80_90.py +++ b/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -117,6 +117,82 @@ class TestEVTCompute(EVTTestCaseBase): input_keys = ["C", "alpha", "beta"] result_keys = ["D"] launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_tanh(self): + """ + Test Tanh op + """ + def evt_tanh(accum): + D = tanh(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_tanh, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_sigmoid(self): + """ + Test Sigmoid op + """ + def evt_sigmoid(accum): + D = sigmoid(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_sigmoid, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_gelu(self): + """ + Test GELU op + """ + def evt_gelu(accum): + D = gelu(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_gelu, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_exp(self): + """ + Test Exp op + """ + def evt_exp(accum): + D = exp(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_exp, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) if __name__ == '__main__': unittest.main() diff --git a/test/python/cutlass/evt/evt_mixed_sm80_90.py b/test/python/cutlass/evt/evt_mixed_sm80_90.py index 448e4b70..45392af6 100644 --- a/test/python/cutlass/evt/evt_mixed_sm80_90.py +++ b/test/python/cutlass/evt/evt_mixed_sm80_90.py @@ -49,6 +49,51 @@ cutlass.set_log_level(logging.WARNING) @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") class TestEVTMixed(EVTTestCaseBase): + + def test_same_variable_used_multiple_times(self): + """ + The same variable z0 is used multiple times + """ + def evt_aux_store(accum): + z0 = relu(accum) + D = z0 + z0 + return z0, D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "z0": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["accum"] + result_keys = ["z0", "D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_no_lca(self): + """ + The same variable z0 is used multiple times + """ + def evt_no_lca(accum, bias): + E = relu(accum) + F = E + bias + tmp_2 = E + 2 + D = tmp_2 + E + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)), + } + + launcher = EVTTestBed(self.element, evt_no_lca, example_inputs) + input_keys = ["accum", "bias"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + def test_mixed_dag(self): def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): F = alpha * accum + (beta * C + aux) diff --git a/test/python/cutlass/evt/evt_store_sm80_90.py b/test/python/cutlass/evt/evt_store_sm80_90.py index 772e06b8..9ff3d7d7 100644 --- a/test/python/cutlass/evt/evt_store_sm80_90.py +++ b/test/python/cutlass/evt/evt_store_sm80_90.py @@ -49,6 +49,31 @@ cutlass.set_log_level(logging.WARNING) @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") class TestEVTStore(EVTTestCaseBase): + @unittest.skipIf(device_cc() != 90, "This test is only for CC 90") + def test_invalid_store(self): + """ + Test invalid store + """ + def evt_invalid_store(accum): + D = accum + F = D + 1 # D has users, which is not allowed on SM90 or higher + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)) + } + with self.assertRaisesRegex( + RuntimeError, + r"On SM90 or higher, D is expected to be a output node with 0 users " + r"to enable smem reuse between C and D, but got 1" + ): + launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs) + + break # Only need to test once + def test_aux_store(self): """ Returning a tensor with shape [m, n] diff --git a/test/python/cutlass/evt/utils/evt_testbed.py b/test/python/cutlass/evt/utils/evt_testbed.py index f5ee2a33..bd027803 100644 --- a/test/python/cutlass/evt/utils/evt_testbed.py +++ b/test/python/cutlass/evt/utils/evt_testbed.py @@ -185,7 +185,9 @@ class EVTTestBed: # Compare the results for result, ref in zip(result_keys, reference_results): - assert torch.equal(epilogue_args[result].flatten(), ref.flatten()) + assert torch.equal( + epilogue_args[result].flatten(), + ref.masked_fill(torch.isnan(ref), float('inf')).flatten()) # Run profile if self.profile: @@ -210,8 +212,11 @@ class EVTTestCaseBase(unittest.TestCase): torch.random.manual_seed(42) - def fake_tensor(self, element, shape): - return Tensor(element=element, shape=shape, layout_tag=cutlass.LayoutType.RowMajor) + def fake_tensor(self, element, shape, stride=None): + if stride is None: + return Tensor(element=element, shape=shape, layout_tag=cutlass.LayoutType.RowMajor) + else: + return Tensor(element=element, shape=shape, stride=stride) def get_problem_sizes(self, alignment, k=None, batch_count=[3,]): k = k if k else self.k diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 5ee68881..a21d8698 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -121,6 +121,7 @@ void FilterArchitecture() { { "SM89*", 89, 89}, { "SM90*", 90, 90}, { "SM100*", 100, 100}, + { "*sm100_*", 100, 100}, { 0, 0, false } }; diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu index b98ed8eb..56a2989a 100644 --- a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -100,6 +100,53 @@ TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op EXPECT_TRUE(test::conv::device::TestAllConv()); } +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta_scaled_bias_relu_residual) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using ElementSrc = cutlass::half_t; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::PerColResAddPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias, ElementSrc>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementSrc, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + // // Cluster tile shape 128x64x64 // Cluster shape 1x1x1 diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu index 42c789dc..1253a12d 100644 --- a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu @@ -234,6 +234,53 @@ TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s EXPECT_TRUE(test::conv::device::TestAllConv()); } +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_scaled_bias_relu_residual) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementSrc = int8_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::PerColResAddPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias, ElementSrc>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementSrc, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + // alpha != 1 && beta != 0 && bias && gelu TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_gelu) { using ElementAct = int8_t; diff --git a/test/unit/conv/device_3x/testbed_conv.hpp b/test/unit/conv/device_3x/testbed_conv.hpp index 031e9a91..99ba9c40 100644 --- a/test/unit/conv/device_3x/testbed_conv.hpp +++ b/test/unit/conv/device_3x/testbed_conv.hpp @@ -176,6 +176,8 @@ struct ConvTestbed { static constexpr bool DisableSource = cute::is_void_v; + static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd::value; + using StrideC = typename Conv::ConvKernel::StrideC; using StrideD = typename Conv::ConvKernel::StrideD; using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp; @@ -494,6 +496,7 @@ struct ConvTestbed { ElementCompute, ElementC, ElementD, + IsResidualEnabled, decltype(mAlpha), decltype(mBeta), decltype(mBias), diff --git a/test/unit/cute/core/array_subbyte.cpp b/test/unit/cute/core/array_subbyte.cpp index ccf0f662..aed37a1c 100644 --- a/test/unit/cute/core/array_subbyte.cpp +++ b/test/unit/cute/core/array_subbyte.cpp @@ -148,6 +148,19 @@ TEST(CuTe_core, Subbyte_iterator) } + { + array_subbyte a{}; + auto tensor = make_tensor(a.begin(), make_shape(15)); + + fill(a, uint6b_t(13)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(uint6b_t(tensor(i)), uint6b_t(13)); + tensor(i) = uint6b_t(i); + EXPECT_EQ(uint6b_t(a[i]), uint6b_t(tensor(i))); + } + + } + { array_subbyte a{}; auto tensor = make_tensor(a.begin(), make_shape(15)); diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 00a1fef7..bd6b0629 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -950,3 +950,32 @@ cutlass_test_unit_gemm_device_add_executable( endif() +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_simt_sm100 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_f32_f32_f32_simt_align1.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_simt_sm100_bias_relu + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_f32_f32_f32_simt_align1_bias_relu.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_simt_sm100_ptr_array + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + sm100_gemm_f32_f32_f32_simt_align1_ptr_array.cu +) + diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index b601054f..731076f0 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -190,13 +190,7 @@ private: template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } template @@ -286,26 +280,26 @@ bool initialize_tensor( scope_max = 2; scope_min = 0; } - + else if (bits_input <= 6) { scope_max = 2; scope_min = -2; } - + else if (bits_input <= 8) { - + if constexpr ( cute::is_same_v){ scope_max = 4; scope_min = 1; } else { - + scope_max = 1; scope_min = -1; - + } - + } else{ scope_max = 4; @@ -354,11 +348,11 @@ static constexpr bool is_row_or_col_major(){ // Default MMA input Operands : A , B // template< - class ScheduleType_, - class Gemm, + class ScheduleType_, + class Gemm, class ElementA_ = typename Gemm::GemmKernel::ElementA, class ElementB_ = typename Gemm::GemmKernel::ElementB, - class Enable = void> + class Enable = void> struct HostCollectiveMainloop { // Kernel data types using ElementA = ElementA_; @@ -520,7 +514,7 @@ struct HostCollectiveMainloop { Arguments to_args() { - + // Runtime datatype selection if constexpr (not cute::is_same_v) { using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; @@ -531,13 +525,13 @@ struct HostCollectiveMainloop { }; } else { - - Arguments arguments = + + Arguments arguments = { tensor_A.device_data(), stride_a, tensor_B.device_data(), stride_b }; return arguments; - } + } } auto to_host_args(ProblemShapeType problem_size) { @@ -555,19 +549,19 @@ struct HostCollectiveMainloop { auto B = make_tensor(make_iterator(tensor_B.host_data()), make_layout(make_shape(N, K, L), stride_b)); - + auto dummy_SFA = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, K, L), stride_a)); auto dummy_SFB = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(N, K, L), stride_b)); - - cutlass::reference::host::GettMainloopParams mainloop_params{}; mainloop_params.A = A; @@ -631,7 +625,7 @@ template< class ElementB_> struct HostCollectiveMainloopSparse { - + // Kernel data types using ElementA = ElementA_; // CuTe layout A for the kernel's sparse tensorA. @@ -875,8 +869,8 @@ struct HostCollectiveMainloopSparse }; template< - class ScheduleType_, - class Gemm, + class ScheduleType_, + class Gemm, class ElementA_, class ElementB_ > @@ -1076,7 +1070,7 @@ struct HostCollectiveMainloop::layout_factory(a_coord, stride_factor_A)); tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); - + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); @@ -1098,7 +1092,7 @@ struct HostCollectiveMainloop::layout_factory(sfa_coord, stride_factor_A)); tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); @@ -1145,12 +1139,12 @@ struct HostCollectiveMainloop + > mainloop_params{A, SfA, B, SfB}; return mainloop_params; } @@ -1184,7 +1178,7 @@ template< class ElementA_, class ElementB_ > -struct HostCollectiveMainloop, +struct HostCollectiveMainloop, Gemm, ElementA_, ElementB_> : public HostCollectiveMainloop, Gemm, ElementA_, ElementB_> { @@ -1454,7 +1448,7 @@ struct HostCollectiveMainloop::layout_factory(sfa_coord, stride_factor_A)); tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); @@ -1503,12 +1497,12 @@ struct HostCollectiveMainloop + > mainloop_params{A, SfA, B, SfB}; return mainloop_params; } @@ -1577,7 +1571,7 @@ struct HostCollectiveMainloop; static_assert(cute::is_base_of_v); - + // Scale factor Generation related using SfStrategy = cutlass::reference::host::SfStrategy; static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; @@ -1880,11 +1874,11 @@ struct HostCollectiveEpilogue { using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; - using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; cutlass::HostTensor tensor_SFD; cutlass::HostTensor reference_SFD; - + using ElementCompute = typename FusionOp::ElementCompute; using ElementScalar = typename FusionOp::ElementScalar; using ElementBias = non_void_t; @@ -1968,9 +1962,9 @@ struct HostCollectiveEpilogue { cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed - ): init_scale(init_scale_), init_bias(init_bias_), - init_C(init_C_), seed(seed_), - stride_factor_C(typename LayoutTagC::Stride()), + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), stride_factor_D(typename LayoutTagD::Stride()), check_relative_equality(check_relative_equality_), use_device_scalars(use_device_scalars_){ } @@ -2172,7 +2166,7 @@ struct HostCollectiveEpilogue { } } - + if constexpr (IsBlockScaleSupported) { auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); @@ -2191,7 +2185,7 @@ struct HostCollectiveEpilogue { EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); norm_constant.sync_device(); } - + return true; } @@ -2258,7 +2252,7 @@ struct HostCollectiveEpilogue { } } #endif - std::cout<<"D is incorrect"<(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - Arguments arguments = + Arguments arguments = { {}, tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d @@ -2484,7 +2478,7 @@ struct HostCollectiveEpilogue { } } - + if constexpr (IsBlockScaleSupported) { arguments.thread.block_scale_factor_ptr = tensor_SFD.device_data(); arguments.thread.norm_constant_ptr = norm_constant.device_data(); @@ -2550,7 +2544,7 @@ struct HostCollectiveEpilogue { cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); } }(); - + auto SfD = [&](){ if constexpr (IsBlockScaleSupported) { auto tensor = make_tensor(detail::make_iterator(reference_SFD.host_data()), @@ -2574,11 +2568,11 @@ struct HostCollectiveEpilogue { decltype(Valpha), decltype(Vbeta), ActivationFunctor, - decltype(SfD), - Int, + decltype(SfD), + Int, cutlass::plus, IsColBiasEnabled - , SfGenStrategy + , SfGenStrategy > epilogue_params{}; epilogue_params.C = C; @@ -2593,7 +2587,7 @@ struct HostCollectiveEpilogue { epilogue_params.scale_d = scale_D.at(coord_0); } - if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) + if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) { epilogue_params.Bias = Bias; } @@ -2628,7 +2622,7 @@ struct HostCollectiveEpilogue { epilogue_params.Vbeta = Vbeta; } } - + if constexpr (IsBlockScaleSupported) { epilogue_params.SfD = SfD; epilogue_params.st = norm_constant.at(coord_0); @@ -2643,19 +2637,19 @@ template < bool force_legacy_epilogue = false, typename ElementA = typename Gemm::GemmKernel::ElementA, typename ElementB = typename Gemm::GemmKernel::ElementB - , typename RuntimeDatatypeA = void* - , typename RuntimeDatatypeB = void* + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* > struct TestbedImpl { // Kernel data types using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type using HostCollectiveMainloopType = HostCollectiveMainloop; - - using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, - HostCollectiveDefaultEpilogue, + + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, HostCollectiveEpilogue>; - + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; using ElementCompute = typename ElementComputeType::Type; @@ -2666,7 +2660,7 @@ struct TestbedImpl { using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; - + using InternalElementA = typename Gemm::GemmKernel::ElementA; using InternalElementB = typename Gemm::GemmKernel::ElementB; static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); @@ -2674,11 +2668,11 @@ struct TestbedImpl { static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || - (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - + uint32_t sm_count; // Used to force multi-wave tests for persistent kernel schedules @@ -2705,7 +2699,7 @@ struct TestbedImpl { cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed - ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } TestbedImpl( @@ -2759,7 +2753,7 @@ struct TestbedImpl { file << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; - + collective_mma_inputs.print_tensors(file); collective_epilogue.print_tensors(file); } @@ -2777,7 +2771,7 @@ struct TestbedImpl { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); auto epilogue_params = collective_epilogue.to_host_args(problem_size); - + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); bool passed = compare_reference(problem_shape_MNKL, alpha, beta); @@ -2865,12 +2859,12 @@ struct TestbedImpl { detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, detail::Splits splits = detail::Splits{}, DecompositionMode decomposition_mode = DecompositionMode::Heuristic - , RuntimeDatatypeA runtime_input_datatype_a = {} - , RuntimeDatatypeB runtime_input_datatype_b = {} + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} ) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) - CUTLASS_TRACE_HOST("TestbedImpl::run"); + CUTLASS_TRACE_HOST("TestbedImpl::run"); #endif // Fail test if insufficient CUDA device @@ -2933,12 +2927,12 @@ struct TestbedImpl { mainloop_args = collective_mma_inputs.to_args(); - + if constexpr (IsRuntimeDataType) { mainloop_args.runtime_data_type_a = runtime_input_datatype_a; mainloop_args.runtime_data_type_b = runtime_input_datatype_b; } - + arguments = { @@ -3062,19 +3056,19 @@ template < bool force_legacy_epilogue = false, typename ElementA = typename Gemm::GemmKernel::ElementA, typename ElementB = typename Gemm::GemmKernel::ElementB - , typename RuntimeDatatypeA = void* - , typename RuntimeDatatypeB = void* + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* > struct Testbed3x { using TestBedImpl = typename detail::TestbedImpl< - Gemm, - ActivationFunctor, - force_legacy_epilogue, - ElementA, + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, ElementB - , RuntimeDatatypeA - , RuntimeDatatypeB + , RuntimeDatatypeA + , RuntimeDatatypeB >; using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; @@ -3115,13 +3109,13 @@ struct Testbed3x { DecompositionMode decomposition_mode = DecompositionMode::Heuristic, bool profiling = false, detail::Iterations iterations = detail::Iterations{} - , RuntimeDatatypeA runtime_input_datatype_a = {} - , RuntimeDatatypeB runtime_input_datatype_b = {} + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} ) { return impl_.run( problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode - , runtime_input_datatype_a, runtime_input_datatype_b + , runtime_input_datatype_a, runtime_input_datatype_b ); } }; @@ -3176,7 +3170,7 @@ bool TestGemmPerf3x(int iterations = 20) { ///////////////////////////////////////////////////////////////////////////////////////////////// // template < - typename Gemm, + typename Gemm, typename RuntimeDataTypeA, typename RuntimeDataTypeB, bool force_legacy_epilogue = false> @@ -3266,8 +3260,8 @@ bool TestRuntimeDataTypeSmall( problem_splits.push_back(detail::Splits{2}); } for (auto splits : problem_splits) { - - if constexpr (cute::is_same_v && + + if constexpr (cute::is_same_v && cute::is_same_v) { // e2m1_e2m1 if (runtime_input_datatype_a == cute::UMMA::MXF4Format::E2M1 && @@ -3300,16 +3294,16 @@ bool TestRuntimeDataTypeSmall( return false; } } - - else - if constexpr (cute::is_same_v && + + else + if constexpr (cute::is_same_v && cute::is_same_v) { static_assert((cute::is_same_v || cute::is_same_v || - cute::is_same_v) && + cute::is_same_v) && (cute::is_same_v || cute::is_same_v || - cute::is_same_v), + cute::is_same_v), "Runtime datatype must be selected with an appropriate static umbrella data type."); if constexpr (cute::is_same_v && cute::is_same_v) { @@ -3483,7 +3477,7 @@ bool TestRuntimeDataTypeSmall( return false; } } - else + else if constexpr (cute::is_same_v && cute::is_same_v) { // e5m2_e5m2 @@ -3622,16 +3616,16 @@ bool TestRuntimeDataTypeSmall( template bool TestSmall(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, - CheckEquality check_relative_equality = CheckEquality::RELATIVE, - ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode = VectorScale::ENABLED, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { - + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; - CtaShape_MNK cta_shape; + CtaShape_MNK cta_shape; Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); static constexpr int SmCount = 16; static constexpr int MultiplierOffsetM = 1; @@ -3901,7 +3895,7 @@ bool TestAll(double alpha = 1.0, double beta = cute::is_same_v(max_swizzle_size) @@ -3912,7 +3906,7 @@ bool TestAll(double alpha = 1.0, double beta = cute::is_same_v(max_swizzle_size) @@ -3923,7 +3917,7 @@ bool TestAll(double alpha = 1.0, double beta = cute::is_same_v(max_swizzle_size) diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index d2badb8c..0c8cc2c0 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -142,13 +142,7 @@ private: template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } template @@ -224,26 +218,26 @@ bool initialize_tensor( scope_max = 2; scope_min = 0; } - + else if (bits_input <= 6) { scope_max = 2; scope_min = -2; } - + else if (bits_input <= 8) { - + if constexpr ( cute::is_same_v){ scope_max = 4; scope_min = 1; } else { - + scope_max = 1; scope_min = -1; - + } - + } else{ scope_max = 4; @@ -292,10 +286,10 @@ static constexpr bool is_row_or_col_major(){ // Default MMA input Operands : A , B // template< - class ScheduleType_, - class Gemm, + class ScheduleType_, + class Gemm, class ElementA_ = typename Gemm::GemmKernel::ElementA, - class ElementB_ = typename Gemm::GemmKernel::ElementB> + class ElementB_ = typename Gemm::GemmKernel::ElementB> struct HostCollectiveMainloop { // Kernel data types using ElementA = ElementA_; @@ -432,13 +426,13 @@ struct HostCollectiveMainloop { if constexpr (IsGroupGemm) { arguments - = + = { device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), stride_b_device.get() }; - } + } else { - arguments = + arguments = { device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] }; @@ -458,8 +452,8 @@ struct HostCollectiveMainloop { auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), make_layout(make_shape(N, K, 1), stride_b_host[batch])); - cutlass::reference::host::GettMainloopParams mainloop_params{}; @@ -542,7 +536,7 @@ struct HostCollectiveMainloop; static constexpr bool IsGroupGemm = !cute::is_same_v; - + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; using ElementScalingFactor = ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -627,7 +621,7 @@ struct HostCollectiveMainloop(problem_shapes.get_host_problem_shape(0), 1); L = std::max(problem_shapes.groups(), L); @@ -636,7 +630,7 @@ struct HostCollectiveMainloop(shape(SfAtom{}))); auto m_blks = cutlass::ceil_div(M, Blk_MN{}); auto n_blks = cutlass::ceil_div(N, Blk_MN{}); layout_sfa_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1))); layout_sfb_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1))); - + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{}, k_blks * Blk_SF{}); auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{}, k_blks * Blk_SF{}); @@ -717,13 +711,13 @@ struct HostCollectiveMainloop + > {A, SfA, B, SfB}; } @@ -795,7 +789,7 @@ template< class ElementA_, class ElementB_ > -struct HostCollectiveMainloop, +struct HostCollectiveMainloop, Gemm, ElementA_, ElementB_> : public HostCollectiveMainloop, Gemm, ElementA_, ElementB_> { @@ -820,7 +814,7 @@ template< class ElementA_, class ElementB_ > -struct HostCollectiveMainloop, +struct HostCollectiveMainloop, Gemm, ElementA_, ElementB_> : public HostCollectiveMainloop, Gemm, ElementA_, ElementB_> { @@ -854,7 +848,7 @@ struct HostCollectiveDefaultEpilogue { using ElementC = non_void_t; using StrideC = typename kernel::StrideC; using InternalStrideC = typename kernel::InternalStrideC; - + static constexpr bool IsGroupGemm = !cute::is_same_v; using FusionOp = typename Gemm::EpilogueOutputOp; @@ -884,7 +878,7 @@ struct HostCollectiveDefaultEpilogue { /// Initialization cutlass::DeviceAllocation stride_c_device; cutlass::DeviceAllocation stride_d_device; - + std::vector stride_c_host; std::vector stride_d_host; @@ -920,15 +914,15 @@ struct HostCollectiveDefaultEpilogue { cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed - ): init_C(init_C_), seed(seed_), - stride_factor_C(typename LayoutTagC::Stride()), + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), stride_factor_D(typename LayoutTagD::Stride()), check_relative_equality(check_relative_equality_), use_device_scalars(use_device_scalars_){ } bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { // Initialize Epilogue tensors - + tensors_C.clear(); tensors_D.clear(); references_D.clear(); @@ -991,7 +985,7 @@ struct HostCollectiveDefaultEpilogue { return cutlass::reference::host::TensorEquals(lhs, rhs); } } - + bool compare_reference( ProblemShapeType problem_shapes, ElementScalar alpha, @@ -1013,7 +1007,7 @@ struct HostCollectiveDefaultEpilogue { bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); if(!passed) { - std::cout<<"D is incorrect"<); - + // Scale factor Generation related using SfStrategy = cutlass::reference::host::SfStrategy; static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; @@ -1164,7 +1158,7 @@ struct HostCollectiveEpilogue { SFD_VectorSize >; using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; - using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; std::vector> tensors_SFD; std::vector> references_SFD; @@ -1197,7 +1191,7 @@ struct HostCollectiveEpilogue { /// Initialization cutlass::DeviceAllocation stride_c_device; cutlass::DeviceAllocation stride_d_device; - + std::vector stride_c_host; std::vector stride_d_host; @@ -1216,7 +1210,7 @@ struct HostCollectiveEpilogue { std::vector> tensors_C; cutlass::DeviceAllocation device_tensors_C; cutlass::HostTensor norm_constant; - + // Outputs cutlass::HostTensor abs_max_Aux; cutlass::HostTensor abs_max_D; @@ -1256,25 +1250,25 @@ struct HostCollectiveEpilogue { cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed - ): init_scale(init_scale_), init_bias(init_bias_), - init_C(init_C_), seed(seed_), - stride_factor_C(typename LayoutTagC::Stride()), + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), stride_factor_D(typename LayoutTagD::Stride()), check_relative_equality(check_relative_equality_), use_device_scalars(use_device_scalars_){ } bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { // Initialize Epilogue tensors - + tensors_C.clear(); tensors_D.clear(); references_D.clear(); stride_c_host.clear(); stride_d_host.clear(); - + tensors_SFD.clear(); references_SFD.clear(); - + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = std::max(problem_shapes.groups(), L); @@ -1406,7 +1400,7 @@ struct HostCollectiveEpilogue { } } - + if constexpr (IsBlockScaleSupported) { for (int32_t i = 0; i < L; ++i) { auto [M, N, K, _] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); @@ -1424,7 +1418,7 @@ struct HostCollectiveEpilogue { EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); norm_constant.sync_device(); } - + return true; } @@ -1457,7 +1451,7 @@ struct HostCollectiveEpilogue { return cutlass::reference::host::TensorEquals(lhs, rhs); } } - + bool compare_reference( ProblemShapeType problem_shapes, ElementScalar alpha, @@ -1476,7 +1470,7 @@ struct HostCollectiveEpilogue { bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); if(!passed) { - std::cout<<"D is incorrect"< ? nullptr : beta.device_data(); - + if constexpr (IsScaleFactorEnabled) { fusion_args.scale_a = scale_A.at(coord_0); fusion_args.scale_b = scale_B.at(coord_0); @@ -1717,7 +1711,7 @@ struct HostCollectiveEpilogue { fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); } } - + if constexpr (IsBlockScaleSupported) { std::vector ptr_SFD_host(L); for (int32_t i = 0; i < L; ++i) { @@ -1729,7 +1723,7 @@ struct HostCollectiveEpilogue { arguments.thread.block_scale_factor_ptr = device_tensors_SFD.get(); arguments.thread.norm_constant_ptr = norm_constant.device_data(); } - + } return arguments; @@ -1763,7 +1757,7 @@ struct HostCollectiveEpilogue { cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); - + auto SfD = [&](){ if constexpr (IsBlockScaleSupported) { auto tensor = make_tensor(detail::make_iterator(references_SFD[batch].host_data()), @@ -1775,7 +1769,7 @@ struct HostCollectiveEpilogue { return D; } }(); - + cutlass::reference::host::GettEpilogueParams< ElementScalar, @@ -1789,11 +1783,11 @@ struct HostCollectiveEpilogue { decltype(Valpha), decltype(Vbeta), ActivationFunctor - , decltype(SfD) - , Int + , decltype(SfD) + , Int , cutlass::plus , false - , SfGenStrategy + , SfGenStrategy > epilogue_params{}; epilogue_params.C = C; @@ -1836,12 +1830,12 @@ struct HostCollectiveEpilogue { epilogue_params.Vbeta = Vbeta; } } - + if constexpr (IsBlockScaleSupported) { epilogue_params.SfD = SfD; epilogue_params.st = norm_constant.at(coord_0); } - + return epilogue_params; } }; @@ -1858,8 +1852,8 @@ struct TestbedImpl { using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type using HostCollectiveMainloopType = HostCollectiveMainloop; - using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, - HostCollectiveDefaultEpilogue, + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, HostCollectiveEpilogue>; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -1899,7 +1893,7 @@ struct TestbedImpl { cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed - ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } TestbedImpl( @@ -2127,10 +2121,10 @@ template < struct Testbed3x { using TestBedImpl = typename detail::TestbedImpl< - Gemm, - ActivationFunctor, - force_legacy_epilogue, - ElementA, + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, ElementB >; using Kernel = typename Gemm::GemmKernel; @@ -2220,7 +2214,7 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative cutlass::from_real(alpha), cutlass::from_real(beta) ); - } + } else { ProblemShapeType problem_size{{m, n, k, batch}}; @@ -2247,9 +2241,9 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative template bool TestSmall(double alpha = 1.0, double beta = 1.0, - CheckEquality check_relative_equality = CheckEquality::RELATIVE, - ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, - VectorScale vector_scale_mode = VectorScale::ENABLED, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; @@ -2257,13 +2251,13 @@ bool TestSmall(double alpha = 1.0, double beta = 1.0, using ElementB = typename Gemm::GemmKernel::ElementB; using TiledMma = typename Gemm::GemmKernel::TiledMma; int alignment_bits = 128; - + static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); alignment_bits = cutlass::detail::get_input_alignment_bits(); // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. int alignment_input = (alignment_bits / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits / cute::sizeof_bits::value); - + if constexpr (apply_alignment_offset) { // If BlockScaled, then min alignment is SFVecSize static constexpr bool IsBlockScaleSupported = Gemm::EpilogueOutputOp::IsBlockScaleSupported; @@ -2272,13 +2266,13 @@ bool TestSmall(double alpha = 1.0, double beta = 1.0, alignment_input = cutlass::round_up(alignment_input, SFVecSize); } } - + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; CtaShape_MNK cta_shape; Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); - // For Ptr-Array and Grouped GEMM ideally we need to know SM count at runtime + // For Ptr-Array and Grouped GEMM ideally we need to know SM count at runtime static constexpr int SmCount = 16; float waves[] = {0.5, 2.5}; diff --git a/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1.cu b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1.cu new file mode 100644 index 00000000..10f6b904 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1.cu @@ -0,0 +1,321 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface (SGEMM) +*/ + +#include "cutlass/cutlass.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// CTA tile shape: 128x128x16 + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_f32n_f32t_f32n_simt_f32_align1, 128x128x16) { + // NT layout + using namespace cute; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_128, _128, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32n_f32n_simt_f32_align1, 128x128x16) { + // TN layout + using namespace cute; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_128, _128, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// CTA tile shape: 64x32x16 + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_f32n_f32n_f32n_simt_f32_align1, 64x32x16) { + // NN layout + using namespace cute; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_64, _32, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32t_f32n_simt_f32_align1, 64x32x16) { + // TT layout + using namespace cute; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_64, _32, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_bias_relu.cu b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_bias_relu.cu new file mode 100644 index 00000000..cd094504 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_bias_relu.cu @@ -0,0 +1,337 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface (SGEMM) +*/ + +#include "cutlass/cutlass.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_f32n_f32t_f32n_simt_f32_align1_bias_relu, 128x128x16) { + // NT layout + using namespace cute; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = float; + using TileShape = Shape<_128, _128, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + // Treat Clamp as Relu + using FusionOp = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, ElementD, ElementCompute, ElementBias, ElementC, ElementCompute, kAlignmentD + >; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + ,FusionOp + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmallFusion(2.0, 6.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32n_f32t_f32n_simt_f32_align1_bias_relu, 128x256x16) { + // NT layout + using namespace cute; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = float; + using TileShape = Shape<_128, _256, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + // Treat Clamp as Relu + using FusionOp = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, ElementD, ElementCompute, ElementBias, ElementC, ElementCompute, kAlignmentD + >; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + ,FusionOp + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmallFusion(2.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32n_f32n_simt_f32_align1_bias_relu, 64x128x16) { + // NT layout + using namespace cute; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = float; + using TileShape = Shape<_64, _128, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + // Treat Clamp as Relu + using FusionOp = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, ElementD, ElementCompute, ElementBias, ElementC, ElementCompute, kAlignmentD + >; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + ,FusionOp + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmallFusion(2.0, 6.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32n_f32n_simt_f32_align1_bias_relu, 64x32x16) { + // NT layout + using namespace cute; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = float; + using TileShape = Shape<_64, _32, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + // Epilogue + // Treat Clamp as Relu + using FusionOp = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, ElementD, ElementCompute, ElementBias, ElementC, ElementCompute, kAlignmentD + >; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + cutlass::epilogue::EpilogueSimtVectorized + ,FusionOp + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmallFusion(2.0, 0.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_ptr_array.cu new file mode 100644 index 00000000..8591de0c --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_simt_align1_ptr_array.cu @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface (SGEMM) +*/ + +#include "cutlass/cutlass.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// CTA tile shape: 128x128x16 + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_f32n_f32t_f32t_simt_f32_align1_ptr_array, 128x128x16) { + // NT layout + using namespace cute; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_128, _128, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + using EpilogueSchedule = cutlass::epilogue::EpiloguePtrArraySimtVectorized; + using KernelSchedule = cutlass::gemm::KernelPtrArrayMultistage; + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + KernelSchedule + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + EpilogueSchedule + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32n_f32n_simt_f32_align1_ptr_array, 128x128x16) { + // TN layout + using namespace cute; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_128, _128, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + using EpilogueSchedule = cutlass::epilogue::EpiloguePtrArraySimtVectorized; + using KernelSchedule = cutlass::gemm::KernelPtrArrayMultistage; + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + KernelSchedule + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + EpilogueSchedule + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// CTA tile shape: 64x256x16 + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_f32n_f32n_f32t_simt_f32_align1_ptr_array, 64x256x16) { + // NN layout + using namespace cute; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_64, _256, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + using EpilogueSchedule = cutlass::epilogue::EpiloguePtrArraySimtVectorized; + using KernelSchedule = cutlass::gemm::KernelPtrArrayMultistage; + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + KernelSchedule + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + EpilogueSchedule + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32t_f32n_simt_f32_align1_ptr_array, 64x256x16) { + // TT layout + using namespace cute; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = LayoutC; + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using TileShape = Shape<_64, _256, _16>; + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int kAlignmentA = 1; + static constexpr int kAlignmentB = 1; + static constexpr int kAlignmentC = 1; + static constexpr int kAlignmentD = 1; + + using EpilogueSchedule = cutlass::epilogue::EpiloguePtrArraySimtVectorized; + using KernelSchedule = cutlass::gemm::KernelPtrArrayMultistage; + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCount<3>, + KernelSchedule + >::CollectiveOp; + + // Epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, + cutlass::arch::OpClassSimt, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + kAlignmentC, + ElementD, + LayoutD, + kAlignmentD, + EpilogueSchedule + >::CollectiveOp; + + // Kernel + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = test::gemm::device::TestSmall(); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/tools/library/include/cutlass/library/arch_mappings.h b/tools/library/include/cutlass/library/arch_mappings.h index 4dad3125..34939b33 100644 --- a/tools/library/include/cutlass/library/arch_mappings.h +++ b/tools/library/include/cutlass/library/arch_mappings.h @@ -99,7 +99,7 @@ template struct ArchMap { template struct ArchMap { static int const kMin = 89; - static int const kMax = 89; + static int const kMax = 100; }; template struct ArchMap { diff --git a/tools/library/src/reference/block_scaled_gemm_reference_operation.h b/tools/library/src/reference/block_scaled_gemm_reference_operation.h index e0c3a8c1..769da1c8 100644 --- a/tools/library/src/reference/block_scaled_gemm_reference_operation.h +++ b/tools/library/src/reference/block_scaled_gemm_reference_operation.h @@ -59,13 +59,7 @@ namespace library { namespace detail { template auto make_iterator(T* ptr) { - using namespace cute; - if constexpr (cute::is_subbyte_v) { - return subbyte_iterator(ptr); - } - else { - return ptr; - } + return cute::recast_ptr(ptr); } } @@ -73,7 +67,7 @@ auto make_iterator(T* ptr) { template < Provider Provider_, - typename ElementA_, + typename ElementA_, typename LayoutA_, typename ElementSFA_, typename ElementB_, @@ -125,7 +119,7 @@ public: /// Constructor BlockScaledGemmReferenceOperation() { - + // Basic information description_.provider = kProvider; description_.kind = OperationKind::kBlockScaledGemm; @@ -139,7 +133,7 @@ public: description_.C = make_TensorDescription(); description_.D = make_TensorDescription(); description_.SFD = make_TensorDescription(); - + // Epilogue compute and accumulator type description description_.element_epilogue = NumericTypeMap::kId; @@ -147,7 +141,7 @@ public: NumericTypeMap::kId; // Compute capability for gemm reference - description_.tile_description.minimum_compute_capability = + description_.tile_description.minimum_compute_capability = (kProvider == Provider::kReferenceDevice ? 50 : 0); description_.tile_description.maximum_compute_capability = 1024; @@ -158,7 +152,7 @@ public: // Procedural name std::stringstream ss; - ss << "gemm" + ss << "gemm" << "_reference_" << to_string(description_.provider) << "_" << to_string(description_.A.element) << to_string(description_.A.layout) << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout) @@ -221,7 +215,7 @@ public: BlockScaledGemmArguments const &args = *static_cast(arguments); - // Construct cute::Tensor A/B/C + // Construct cute::Tensor A/B/C int M = args.problem_size.m(); int N = args.problem_size.n(); @@ -266,12 +260,12 @@ public: auto D = cute::make_tensor(detail::make_iterator(static_cast(args.D)), cute::make_layout(cute::make_shape(M, N, L), stride_d)); - cutlass::reference::host::GettBlockScalingMainloopParams + cutlass::reference::host::GettBlockScalingMainloopParams mainloop_params{A, SfA, B, SfB}; - if constexpr (not is_same_v) { + if constexpr (not is_same_v) { using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< EpilogueSFVecSize @@ -289,7 +283,7 @@ public: else { // W/O SF generation auto SfD = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, N, L))); // not used. + cute::make_layout(cute::make_shape(M, N, L))); // not used. cutlass::reference::host::GettBlockScalingEpilogueParams< ElementCompute, ElementAccumulator, ElementCompute, decltype(C), decltype(D), decltype(SfD)> @@ -362,7 +356,7 @@ template < typename InnerProductOp_ = multiply_add > void make_block_scaled_gemm(Manifest &manifest) { - /// + /// /// A is Row , B is Col /// manifest.append(new BlockScaledGemmReferenceOperation< @@ -405,7 +399,7 @@ void make_block_scaled_gemm(Manifest &manifest) { ConvertOp_, InnerProductOp_ >); - /// + /// /// A is Col , B is Row /// manifest.append(new BlockScaledGemmReferenceOperation< diff --git a/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h index 08e28db2..368fabb2 100644 --- a/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h @@ -105,8 +105,6 @@ public: cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; int swizzle_size{1}; - - cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; @@ -264,7 +262,8 @@ protected: std::array const &preferred_cluster, std::array const &fallback_cluster, cutlass::library::RasterOrder const &raster_order, - int swizzle_size); + int swizzle_size, + bool is_dynamic_cluster_enabled); /// Update performance result configuration according to flexible user setups void update_result_( @@ -275,7 +274,8 @@ protected: cutlass::library::RasterOrder const &raster_order, std::array const &preferred_cluster, std::array const &fallback_cluster, - int swizzle_size); + int swizzle_size, + bool is_dynamic_cluster_enabled); /// Initializes the performance result void initialize_result_( diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h index 489dc5af..faf31715 100644 --- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -103,8 +103,6 @@ public: cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; int swizzle_size{1}; - - cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; @@ -268,7 +266,8 @@ protected: std::array const &preferred_cluster, std::array const &fallback_cluster, cutlass::library::RasterOrder const &raster_order, - int swizzle_size); + int swizzle_size, + bool is_dynamic_cluster_enabled); /// Update performance result configuration according to flexible user setups void update_result_( @@ -279,7 +278,8 @@ protected: cutlass::library::RasterOrder const &raster_order, std::array const &preferred_cluster, std::array const &fallback_cluster, - int swizzle_size); + int swizzle_size, + bool is_dynamic_cluster_enabled); /// Initializes the performance result void initialize_result_( diff --git a/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h index 72360487..d41b1ba5 100644 --- a/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -286,14 +286,25 @@ protected: library::GroupedGemmDescription const& operation_desc, ProblemSpace const& problem_space); + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GroupedGemmWorkspace &gemm_workspace, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + /// Update performance result configuration for exploration parameters - void update_result_( + void update_workspace_and_result_( + GroupedGemmWorkspace &gemm_workspace, PerformanceResult &result, ProblemSpace const &problem_space, cutlass::library::RasterOrder const &raster_order, std::array const &preferred_cluster, std::array const &fallback_cluster, - int swizzle_size); + int swizzle_size, + bool is_dynamic_cluster_enabled); /// Verifies CUTLASS against host and device references bool verify_with_reference_( diff --git a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu index 078acc96..73239dce 100644 --- a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu +++ b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu @@ -559,7 +559,8 @@ void BlockScaledGemmOperationProfiler::update_workspace_( std::array const &preferred_cluster, std::array const &fallback_cluster, cutlass::library::RasterOrder const &raster_order, - int swizzle_size + int swizzle_size, + bool is_dynamic_cluster_enabled ) { gemm_workspace.arguments.problem_size.m() = problem_shape.m(); @@ -573,15 +574,17 @@ void BlockScaledGemmOperationProfiler::update_workspace_( gemm_workspace.arguments.swizzle_size = swizzle_size; gemm_workspace.arguments.raster_order = raster_order; - gemm_workspace.arguments.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; - gemm_workspace.arguments.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + if (is_dynamic_cluster_enabled) { + gemm_workspace.arguments.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; + gemm_workspace.arguments.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + gemm_workspace.configuration.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; + gemm_workspace.configuration.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + } gemm_workspace.configuration.problem_size.m() = problem_shape.m(); gemm_workspace.configuration.problem_size.n() = problem_shape.n(); gemm_workspace.configuration.problem_size.k() = problem_shape.k(); - gemm_workspace.configuration.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; - gemm_workspace.configuration.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; gemm_workspace.configuration.lda = leading_dim[0]; gemm_workspace.configuration.ldb = leading_dim[1]; @@ -598,7 +601,8 @@ void BlockScaledGemmOperationProfiler::update_result_( cutlass::library::RasterOrder const &raster_order, std::array const &preferred_cluster, std::array const &fallback_cluster, - int swizzle_size + int swizzle_size, + bool is_dynamic_cluster_enabled ) { result.bytes = problem_.bytes_with_problem_shape(operation_desc, problem_shape); result.flops = problem_.flops_with_problem_shape(operation_desc, problem_shape); @@ -609,12 +613,14 @@ void BlockScaledGemmOperationProfiler::update_result_( set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); set_argument(result, "swizzle_size", problem_space, swizzle_size); - set_argument(result, "cluster_m", problem_space, preferred_cluster[0]); - set_argument(result, "cluster_n", problem_space, preferred_cluster[1]); - set_argument(result, "cluster_k", problem_space, preferred_cluster[2]); - set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]); - set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]); - set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]); + if (is_dynamic_cluster_enabled) { + set_argument(result, "cluster_m", problem_space, preferred_cluster[0]); + set_argument(result, "cluster_n", problem_space, preferred_cluster[1]); + set_argument(result, "cluster_k", problem_space, preferred_cluster[2]); + set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]); + set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]); + set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]); + } } @@ -1377,9 +1383,8 @@ bool BlockScaledGemmOperationProfiler::profile( library::BlockScaledGemmDescription const &operation_desc = static_cast(operation->description()); - auto min_cc = operation_desc.tile_description.minimum_compute_capability; - - bool is_dynamic_cluster_enabled = (min_cc >= 100); + auto cluster_shape = operation_desc.tile_description.cluster_shape; + bool is_dynamic_cluster_enabled = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0; // Helper function wrapping up performance test with flexible parameters. auto initialize_and_profile = [&]( @@ -1419,7 +1424,7 @@ bool BlockScaledGemmOperationProfiler::profile( gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; } - update_workspace_(gemm_workspace_, problem_shape, leading_dim, preferred_cluster, fallback_cluster, raster_order, swizzle_size); + update_workspace_(gemm_workspace_, problem_shape, leading_dim, preferred_cluster, fallback_cluster, raster_order, swizzle_size, is_dynamic_cluster_enabled); const auto can_implement = operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); if (can_implement != Status::kSuccess) { @@ -1447,7 +1452,7 @@ bool BlockScaledGemmOperationProfiler::profile( } PerformanceResult curr_result(result); - update_result_(curr_result, operation_desc, problem_space, problem_shape, raster_order, preferred_cluster, fallback_cluster, swizzle_size); + update_result_(curr_result, operation_desc, problem_space, problem_shape, raster_order, preferred_cluster, fallback_cluster, swizzle_size, is_dynamic_cluster_enabled); curr_result.status = profile_cutlass_( curr_result, @@ -1490,16 +1495,12 @@ bool BlockScaledGemmOperationProfiler::profile( PerformanceResult result_base = results_.back(); results_.pop_back(); - bool dynamic_cluster = int64_t(operation_desc.tile_description.cluster_shape.m()) == 0 || - int64_t(operation_desc.tile_description.cluster_shape.n()) == 0 || - int64_t(operation_desc.tile_description.cluster_shape.k()) == 0; - std::vector> preferred_clusters; std::vector> fallback_clusters; // Only loop over built-in cluster shape lists for dynamic cluster kernels // and for kernels that can leverage the dynamic cluster feature. - if (dynamic_cluster && is_dynamic_cluster_enabled) { + if (is_dynamic_cluster_enabled) { preferred_clusters = this->problem_.preferred_clusters; fallback_clusters = this->problem_.fallback_clusters; } @@ -1510,7 +1511,7 @@ bool BlockScaledGemmOperationProfiler::profile( for (auto preferred_cluster : preferred_clusters) { for (auto fallback_cluster : fallback_clusters) { - if (dynamic_cluster && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { + if (is_dynamic_cluster_enabled && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { continue; } for (auto swizzle_size : this->problem_.swizzle_sizes) { diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 60088075..f932772d 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -661,7 +661,8 @@ void GemmOperationProfiler::update_workspace_( std::array const &preferred_cluster, std::array const &fallback_cluster, cutlass::library::RasterOrder const &raster_order, - int swizzle_size + int swizzle_size, + bool is_dynamic_cluster_enabled ) { gemm_workspace.arguments.problem_size.m() = problem_shape.m(); @@ -675,16 +676,17 @@ void GemmOperationProfiler::update_workspace_( gemm_workspace.arguments.swizzle_size = swizzle_size; gemm_workspace.arguments.raster_order = raster_order; - gemm_workspace.arguments.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; - gemm_workspace.arguments.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + if (is_dynamic_cluster_enabled) { + gemm_workspace.arguments.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; + gemm_workspace.arguments.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + gemm_workspace.configuration.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; + gemm_workspace.configuration.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + } gemm_workspace.configuration.problem_size.m() = problem_shape.m(); gemm_workspace.configuration.problem_size.n() = problem_shape.n(); gemm_workspace.configuration.problem_size.k() = problem_shape.k(); - gemm_workspace.configuration.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; - gemm_workspace.configuration.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; - gemm_workspace.configuration.lda = leading_dim[0]; gemm_workspace.configuration.ldb = leading_dim[1]; gemm_workspace.configuration.ldc = leading_dim[2]; @@ -699,7 +701,8 @@ void GemmOperationProfiler::update_result_( cutlass::library::RasterOrder const &raster_order, std::array const &preferred_cluster, std::array const &fallback_cluster, - int swizzle_size + int swizzle_size, + bool is_dynamic_cluster_enabled ) { result.bytes = problem_.bytes_with_problem_shape(operation_desc, problem_shape); result.flops = problem_.flops_with_problem_shape(operation_desc, problem_shape); @@ -711,12 +714,14 @@ void GemmOperationProfiler::update_result_( set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); set_argument(result, "swizzle_size", problem_space, swizzle_size); - set_argument(result, "cluster_m", problem_space, preferred_cluster[0]); - set_argument(result, "cluster_n", problem_space, preferred_cluster[1]); - set_argument(result, "cluster_k", problem_space, preferred_cluster[2]); - set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]); - set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]); - set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]); + if (is_dynamic_cluster_enabled) { + set_argument(result, "cluster_m", problem_space, preferred_cluster[0]); + set_argument(result, "cluster_n", problem_space, preferred_cluster[1]); + set_argument(result, "cluster_k", problem_space, preferred_cluster[2]); + set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]); + set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]); + set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]); + } } @@ -1576,9 +1581,8 @@ bool GemmOperationProfiler::profile( library::GemmDescription const &operation_desc = static_cast(operation->description()); - auto min_cc = operation_desc.tile_description.minimum_compute_capability; - - bool is_dynamic_cluster_enabled = (min_cc >= 100); + auto cluster_shape = operation_desc.tile_description.cluster_shape; + bool is_dynamic_cluster_enabled = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0; // Helper function wrapping up performance test with flexible parameters. auto initialize_and_profile = [&]( @@ -1618,7 +1622,7 @@ bool GemmOperationProfiler::profile( workspace.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; } - update_workspace_(workspace, problem_shape, leading_dim, preferred_cluster, fallback_cluster, raster_order, swizzle_size); + update_workspace_(workspace, problem_shape, leading_dim, preferred_cluster, fallback_cluster, raster_order, swizzle_size, is_dynamic_cluster_enabled); const auto can_implement = operation->can_implement(&workspace.configuration, &workspace.arguments); if (can_implement != Status::kSuccess) { @@ -1672,7 +1676,7 @@ bool GemmOperationProfiler::profile( } PerformanceResult curr_result(result); - update_result_(curr_result, operation_desc, problem_space, problem_shape, raster_order, preferred_cluster, fallback_cluster, swizzle_size); + update_result_(curr_result, operation_desc, problem_space, problem_shape, raster_order, preferred_cluster, fallback_cluster, swizzle_size, is_dynamic_cluster_enabled); curr_result.status = profile_cutlass_( curr_result, @@ -1712,17 +1716,13 @@ bool GemmOperationProfiler::profile( std::vector candidates; PerformanceResult result_base = results_.back(); results_.pop_back(); - - bool dynamic_cluster = int64_t(operation_desc.tile_description.cluster_shape.m()) == 0 || - int64_t(operation_desc.tile_description.cluster_shape.n()) == 0 || - int64_t(operation_desc.tile_description.cluster_shape.k()) == 0; std::vector> preferred_clusters; std::vector> fallback_clusters; // Only loop over built-in cluster shape lists for dynamic cluster kernels // and for kernels that can leverage the dynamic cluster feature. - if (dynamic_cluster && is_dynamic_cluster_enabled) { + if (is_dynamic_cluster_enabled) { preferred_clusters = this->problem_.preferred_clusters; fallback_clusters = this->problem_.fallback_clusters; } @@ -1733,7 +1733,7 @@ bool GemmOperationProfiler::profile( for (auto preferred_cluster : preferred_clusters) { for (auto fallback_cluster : fallback_clusters) { - if (dynamic_cluster && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { + if (is_dynamic_cluster_enabled && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { continue; } for (auto swizzle_size : this->problem_.swizzle_sizes) { diff --git a/tools/profiler/src/grouped_gemm_operation_profiler.cu b/tools/profiler/src/grouped_gemm_operation_profiler.cu index e7b02215..4ef9f564 100644 --- a/tools/profiler/src/grouped_gemm_operation_profiler.cu +++ b/tools/profiler/src/grouped_gemm_operation_profiler.cu @@ -538,23 +538,33 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result( library::lexical_cast(beta, operation_desc.gemm.element_epilogue)); } -void GroupedGemmOperationProfiler::update_result_( +void GroupedGemmOperationProfiler::update_workspace_and_result_( + GroupedGemmWorkspace &gemm_workspace, PerformanceResult &result, ProblemSpace const &problem_space, cutlass::library::RasterOrder const &raster_order, std::array const &preferred_cluster, std::array const &fallback_cluster, - int swizzle_size + int swizzle_size, + bool is_dynamic_cluster_enabled ) { + + gemm_workspace.arguments.swizzle_size = swizzle_size; + gemm_workspace.arguments.raster_order = raster_order; + set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); set_argument(result, "swizzle_size", problem_space, swizzle_size); - set_argument(result, "cluster_m", problem_space, preferred_cluster[0]); - set_argument(result, "cluster_n", problem_space, preferred_cluster[1]); - set_argument(result, "cluster_k", problem_space, preferred_cluster[2]); - set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]); - set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]); - set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]); + if (is_dynamic_cluster_enabled) { + gemm_workspace.arguments.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; + gemm_workspace.arguments.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + set_argument(result, "cluster_m", problem_space, preferred_cluster[0]); + set_argument(result, "cluster_n", problem_space, preferred_cluster[1]); + set_argument(result, "cluster_k", problem_space, preferred_cluster[2]); + set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]); + set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]); + set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]); + } } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1605,9 +1615,8 @@ bool GroupedGemmOperationProfiler::profile_cutlass_for_fixed_shape_( library::GroupedGemmDescription const &operation_desc = static_cast(operation->description()); - auto min_cc = operation_desc.tile_description.minimum_compute_capability; - - bool is_dynamic_cluster_enabled = (min_cc >= 100); + auto cluster_shape = operation_desc.tile_description.cluster_shape; + bool is_dynamic_cluster_enabled = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0; // Helper function to test validity of fallback cluster shapes and preferred cluster shapes. auto is_valid_dynamic_cluster_shape = [](const std::array& preferred_cluster, const std::array& fallback_cluster) { @@ -1636,19 +1645,15 @@ bool GroupedGemmOperationProfiler::profile_cutlass_for_fixed_shape_( PerformanceResult result_base = results_.back(); results_.pop_back(); - bool dynamic_cluster = int64_t(operation_desc.tile_description.cluster_shape.m()) == 0 || - int64_t(operation_desc.tile_description.cluster_shape.n()) == 0 || - int64_t(operation_desc.tile_description.cluster_shape.k()) == 0; - std::vector> preferred_clusters; std::vector> fallback_clusters; // Only loop over built-in cluster shape lists for dynamic cluster kernels // and for kernels that can leverage the dynamic cluster feature. - if (dynamic_cluster && is_dynamic_cluster_enabled) { + if (is_dynamic_cluster_enabled) { preferred_clusters = this->problem_.preferred_clusters; fallback_clusters = this->problem_.fallback_clusters; - } + } else { preferred_clusters = {{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}}; fallback_clusters = {{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}}; @@ -1656,13 +1661,13 @@ bool GroupedGemmOperationProfiler::profile_cutlass_for_fixed_shape_( for (auto preferred_cluster : preferred_clusters) { for (auto fallback_cluster : fallback_clusters) { - if (dynamic_cluster && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { + if (is_dynamic_cluster_enabled && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { continue; } for (auto swizzle_size : this->problem_.swizzle_sizes) { for (auto raster_order : this->problem_.raster_orders) { PerformanceResult curr_result(result_base); - update_result_(curr_result, problem_space, raster_order, preferred_cluster, fallback_cluster, swizzle_size); + update_workspace_and_result_(gemm_workspace_, curr_result, problem_space, raster_order, preferred_cluster, fallback_cluster, swizzle_size, is_dynamic_cluster_enabled); curr_result.status = profile_cutlass_( curr_result, options, diff --git a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp index 0c7a93a3..43f5a3f9 100644 --- a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +++ b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -75,15 +75,7 @@ __global__ void dequantize_kernel(DequantizedElement* dq_buffer, // Represent the full tensors to gmem elements. // These are expected to have shape [MN, K, L] cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); - auto init_quantized_iterator = [&]() { - if constexpr (cute::sizeof_bits_v >= 8) { - return cute::make_gmem_ptr(q_buffer); - } - else { - return cute::subbyte_iterator(q_buffer); - } - }; - cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout); + cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr(q_buffer), operand_layout); // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting // It is expected that K % G == 0 cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); diff --git a/tools/util/include/cutlass/util/reference/host/conv.hpp b/tools/util/include/cutlass/util/reference/host/conv.hpp index 1056f4e9..57443325 100644 --- a/tools/util/include/cutlass/util/reference/host/conv.hpp +++ b/tools/util/include/cutlass/util/reference/host/conv.hpp @@ -95,6 +95,7 @@ template< class ElementCompute_, class ElementC_, class ElementOut_, + bool ResidualAdd_, class TensorAlpha_, class TensorBeta_, class TensorBias_, @@ -110,6 +111,8 @@ struct ConvEpilogueFusionParams { using TensorBeta = TensorBeta_; using TensorBias = TensorBias_; using ActivationFunctor = ActivationFunctor_; + static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation + ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); @@ -228,12 +231,17 @@ private: epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[k]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } tensor_d_(k, q, n, g) = output_converter(output); } } @@ -279,12 +287,17 @@ private: epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[k]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } tensor_d_(k, q, p, n, g) = output_converter(output); } } @@ -337,12 +350,17 @@ private: epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[k]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } tensor_d_(k, q, p, z, n, g) = output_converter(output); } } @@ -389,12 +407,17 @@ private: ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[c]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } tensor_d_(c, w, n, g) = output_converter(output); } } @@ -451,12 +474,17 @@ private: ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[c]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } tensor_d_(c, w, h, n, g) = output_converter(output); } @@ -527,12 +555,17 @@ private: ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[c]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } tensor_d_(c, w, h, d, n, g) = output_converter(output); } } @@ -583,12 +616,17 @@ private: ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[c]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } tensor_d_(c, s, k, g) = output_converter(output); } } @@ -643,12 +681,17 @@ private: ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[c]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } tensor_d_(c, s, r, k, g) = output_converter(output); } } @@ -711,12 +754,17 @@ private: ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; - ElementCompute output = scale_converter(alpha) * acc_converter(accumulator) + - scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[c]); } output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } tensor_d_(c, s, r, t, k, g) = output_converter(output); } }