diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d4e8e75..9dd1b663 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,72 @@ # CUTLASS 4.x +## [4.3.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-10-20) + +### CuTe DSL +* Debuggability improvements: + - Supported source location tracking for DSL APIs + - Supported dumping PTX and CUBIN code +* More examples and notebooks to get started with CuTe DSL: + - [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py) + - Improved performance of elementwise kernel (https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/elementwise_apply.py): + + Generalize code to handle list of input tensors + + Generalize TV layout computation to handle different data types + - Demonstrate the new Pipeline APIs in [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py): + + New Pipeline API `PipelineProducer` and `PipelineConsumer` to simplify code (no more explicit pipeline state management) + - Separate epilogue code for non-TMA and TMA implementation + + Note that the updates simplifies the codes but existing APIs still work and are supported + - [Basic Blackwell SM100 GEMM with decent performance](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py) + + Simple tutorial achieves 84% SOL performance with MNK 8K + - Reworked [elementwise add notebook](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb) with more details and detailed explanation about TV layout + + Updated implementation to handle general data type and multiple inputs + + Updated explanation for TV layout in simpler language + + Added visualization of TV Layout with 3rd party utils + - [Benchmark and autotune demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb) +* More examples of authorizing peak-performance kernels: + - [Blackwell SM100 mixed-input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py) + - [Blackwell SM100 persistent blockwise dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py) + - [Blackwell SM100 persistent blockwise contiguous grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py) + - [Blackwell SM100 persistent blockwise masked grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py) + - [Blackwell SM100 fmha bwd](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha_bwd.py) + - [Blackwell SM100 mla](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mla.py) + - [Hopper SM90 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py) + - [Blackwell GeForce batched dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py) + - [Ampere HSTU Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/hstu_attention.py) +* API updates: + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details +* Bug fixings and improvements + - Add mma_tiler_n=64 and mma_tiler_n=192 support in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py). + - Fixed ``TensorSSA.reduce`` to support static value as initial value + - Updated docstring for following APIs to be more concise and easier to understand: + - ``make_layout_tv`` + - ``is_static`` + - ``PipelineAsync`` + - ``SmemAllocator`` + - Fixed documentation for ``pipeline``, ``utils`` and ``cute.math`` + +### CUTLASS C++ +* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). + - Add softmax skip correction. + - Fix a shared memory allocation bug where it needs to opt in maximum dynamics shared memory explicitly once it exceeds 48KB. + - Fix a dead hang issue caused by early return warp. +* Add Ragged Contiguous Grouped gemm kernel in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). + - This kernel uses a TMA 3D load to load the weights matrix and use the tensormap update method to load activations. +* Optimize group gemm kernels by enabling async TMA desc update. +* Support Blackwell SM100 convolution stream-K kernel. + - Unit tests: [fprop_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [dgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [wgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu). +* Add profiler support for Blackwell SM100 and SM120 blockscaled sparse kernels. +* Fix some kernel issues: + - Fix a race check issue of Blackwell SM103 kernels by adding missing elect one for prefetch barrier initialization. + - Allow user to directly specify the number of stages for Hopper sm90 mixed input gemm. + - Remove warnings caused by cuda vector type alignment setting in CUDA 13. + - Remove problematic `cutlass::int8_t` and replace it with `int8_t`. +* Fix some profiler issues: + - Add some missing reference kernels. + - Add calculation of scale factor A and B in function `bytes_with_problem_shape` of block scaled profiler. +* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! +* Optimal code generation with CUDA toolkit versions 13.0U1. + ## [4.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.1) (2025-09-22) ### CuTe DSL @@ -26,7 +92,7 @@ - Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb) + Added a section for introducing the broadcast * API updates - - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details * Bug fixings and improvements - Fixed ``cute.print_tensor`` for coordinate tensor - Fixed `cute.print` for tuple of layouts @@ -95,7 +161,7 @@ - Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels. - Support fp16 accmulator for sm89 fp8 mma. - Shorten `nullspace` implementation. - - Isolate and comment on `cosize` hacks. + - Isolate and comment on `cosize` risky changes. - Important documentation correction: `E<0,1> == 1@0@1`. * Fix some kernel issues: - Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers. @@ -115,7 +181,7 @@ - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py) - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py) * API updates - - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details ### CUTLASS C++ * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). @@ -149,9 +215,9 @@ ### CuTe DSL * CuTe DSL, a Python DSL centered around CuTe's abstractions - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL) - - [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html) - - [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html) -* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass) + - [DSL quick start](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html) + - [DSL Overview](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/overview.html) +* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass/latest) * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels - [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py) - [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py) @@ -163,7 +229,7 @@ - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py) * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks) * API updates - - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details ### CUTLASS C++ * Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9 @@ -251,7 +317,7 @@ - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance. - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. - - More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + - More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). * Support `void` as the D element in sm100 kernel epilogues. * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! * Optimal code generation with CUDA toolkit versions 12.8U1. @@ -270,7 +336,7 @@ - [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp). - [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp). - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. - - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). + - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. * Full support for Blackwell SM100 kernels in CUTLASS 3.x API: - [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that @@ -308,11 +374,11 @@ - A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes. - A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu). * Documentation updates: - - [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel). - - Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html) - - A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. - - Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture). - - Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper. + - [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel). + - Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html) + - A new [functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. + - Updates to [compatibility](https://docs.nvidia.com/cutlass/latest/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/latest/overview.html#target-architecture). + - Updates to [profiler documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper. ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11) - [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439). @@ -325,7 +391,7 @@ + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication. + Remove `cute::copy_vec` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment,...)`. + A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel. -- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler). +- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#cutlass-profiler). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! - Optimal code generation with CUDA toolkit versions 12.6. @@ -339,12 +405,12 @@ - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. - [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. - [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). -- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html). -- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html). +- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. - A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. - A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). - A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. -- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper). +- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper). - A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h) - Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! @@ -367,7 +433,7 @@ - Support for residual add (beta != 0) in convolution kernels. - A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. - A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt). -- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html). +- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html). - Better support for MSVC as a host compiler. - Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. - Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. @@ -375,7 +441,7 @@ ## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp) - + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html). + + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html). + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp). + Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms + [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. diff --git a/CMakeLists.txt b/CMakeLists.txt index 23be6991..627d1713 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,16 @@ endif() include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) +# nvcc supports response files with --options-file but some tools like clangd +# might choke on it. Thus provide a way to control the use of this feature. +set(CUTLASS_CUDA_USE_RESPONSE_FILE ON CACHE BOOL "Enable CUDA response files for includes, libraries, and objects") + +if(NOT CUTLASS_CUDA_USE_RESPONSE_FILE) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0) +endif() + if (CUDA_VERSION VERSION_LESS 11.3) message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 11.4 or higher, and strongly recommends CUDA 11.8 or higher.") elseif (CUDA_VERSION VERSION_LESS 11.4) @@ -804,9 +814,9 @@ if(NOT WIN32) # Add common library search paths so executables and libraries can load and run # without LD_LIBRARY_PATH being set. link_libraries( - "-Wl,-rpath,'$ORIGIN'" - "-Wl,-rpath,'$ORIGIN/../lib64'" - "-Wl,-rpath,'$ORIGIN/../lib'" + "-Wl,-rpath,'$$ORIGIN'" + "-Wl,-rpath,'$$ORIGIN/../lib64'" + "-Wl,-rpath,'$$ORIGIN/../lib'" "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'" "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'" ${CMAKE_DL_LIBS} @@ -934,7 +944,7 @@ function(cutlass_add_executable_tests NAME TARGET) install( FILES ${__RESULT_CACHE_FILE} - DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}/ + DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR} ) endif() @@ -1062,7 +1072,7 @@ function(cutlass_generate_profiler_tests NAME) install( FILES ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} - DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass/ + DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass RENAME profiler_regressions.csv ) diff --git a/README.md b/README.md index 8ce2151a..9e99803b 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.2.1 +# CUTLASS 4.3.0 -_CUTLASS 4.2.1 - Sept 2025_ +_CUTLASS 4.3.0 - Oct 2025_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -27,115 +27,90 @@ native support of such data types) across NVIDIA's Volta, Turing, Ampere, Ada, H To this rich ecosystem of C++ based kernel programming abstractions, CUTLASS 4 adds CUTLASS DSLs. These are Python native interfaces for writing high-performance CUDA kernels based on core CUTLASS and CuTe concepts without any performance compromises. This allows for a much smoother learning curve, orders of magnitude faster compile times, native integration with DL frameworks without writing glue code, and much more intuitive metaprogramming that does not require deep C++ expertise. -Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions -- exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy. +Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions — exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy. CuTe DSL demonstrates optimal matrix multiply and other linear algebra operations targeting the programmable, high-throughput _Tensor Cores_ implemented by NVIDIA's Ampere, Hopper, and Blackwell architectures. We believe it will become an indispensable tool for students, researchers, and performance -engineers alike -- flattening the learning curve of GPU programming, rapidly prototyping kernel +engineers alike — flattening the learning curve of GPU programming, rapidly prototyping kernel designs, and bringing optimized solutions into production. CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025. To get started quickly - please refer : - - [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). - - [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html). + - [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html). + - [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html). -# What's New in CUTLASS 4.2 +# What's New in CUTLASS 4.3 ## CuTe DSL -* More Python versions are now supported for both x86-64 and aarch64, including - - Python 3.10, 3.11, 3.12, and 3.13 -* Added new example and updated notebook to get started with CuTe DSL - - [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py) - - Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb) - + Added a section for introducing the broadcast -* API updates - - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details +* Debuggability improvements: + - Supported source location tracking for DSL APIs + - Supported dumping PTX and CUBIN code +* More examples and notebooks to get started with CuTe DSL: + - [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py) + - Improved performance of elementwise kernel (https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/elementwise_apply.py): + + Generalize code to handle list of input tensors + + Generalize TV layout computation to handle different data types + - Demonstrate the new Pipeline APIs in [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py): + + New Pipeline API `PipelineProducer` and `PipelineConsumer` to simplify code (no more explicit pipeline state management) + - Separate epilogue code for non-TMA and TMA implementation + + Note that the updates simplifies the codes but existing APIs still work and are supported + - [Basic Blackwell SM100 GEMM with decent performance](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py) + + Simple tutorial achieves 84% SOL performance with MNK 8K + - Reworked [elementwise add notebook](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb) with more details and detailed explanation about TV layout + + Updated implementation to handle general data type and multiple inputs + + Updated explanation for TV layout in simpler language + + Added visualization of TV Layout with 3rd party utils + - [Benchmark and autotune demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb) +* More examples of authorizing peak-performance kernels: + - [Blackwell SM100 mixed-input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py) + - [Blackwell SM100 persistent blockwise dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py) + - [Blackwell SM100 persistent blockwise contiguous grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py) + - [Blackwell SM100 persistent blockwise masked grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py) + - [Blackwell SM100 fmha bwd](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha_bwd.py) + - [Blackwell SM100 mla](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mla.py) + - [Hopper SM90 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py) + - [Blackwell GeForce batched dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py) + - [Ampere HSTU Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/hstu_attention.py) +* API updates: + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details * Bug fixings and improvements - - Fixed ``cute.print_tensor`` for coordinate tensor - - Fixed `cute.print` for tuple of layouts - - Fixed frozen object is not properly updated after fully assigned in dynamic control flow - - Fixed assign tuple/list element in a dynamic control flow may cause compilation failure - - Improved error message when CUDA context is not initialized - - Improved docstring of congruent and weakly_congruent + - Add mma_tiler_n=64 and mma_tiler_n=192 support in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py). + - Fixed ``TensorSSA.reduce`` to support static value as initial value + - Updated docstring for following APIs to be more concise and easier to understand: + - ``make_layout_tv`` + - ``is_static`` + - ``PipelineAsync`` + - ``SmemAllocator`` + - Fixed documentation for ``pipeline``, ``utils`` and ``cute.math`` ## CUTLASS C++ -* Support for Blackwell SM103 kernels for B300 GPUs. - - Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp) - - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. - - Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp). -* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture: - - [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/). - - [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm). -* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM - - Unit test files with prefix name of `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/). -* Support for Blackwell SM121 kernels for DGX Spark GPUs. - - Share the major codes with Blackwell SM120 kernels. -* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario. - - Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md). * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). - - Add fused reduction kernel support for cutlass MLA. - Add softmax skip correction. - - Support for GQA in FMHA backward kernel. - - Fix an issue where `get_unmasked_trip_count` may return a negative value. - - Fix an issue where mbarriers are initialized with a zero arrival count. - - Fix a corner case issue where the sequence length of q is not a multiple of tile_q. - - Remove tma padding for forward kernel inputs. -* Add Blackwell SM100 kernels for MoEs (focusing on Low-Latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allow only one problem dimension to vary across groups/experts, unlike general Grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on API is welcome. -* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell - - On Blackwell SM120, a blockwise gemm kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/). - - On Hopper, add K major scale factor support for SM90 blockwise kernels. - - On Hopper, relax the restriction that the k dimension of the problem size has to be the multiple of the k dimension of the tile size. - - On Hopper, grouped version supports the case when k = 0. -* Support for Blackwell SM100 fp4 gemv kernels. - - Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h). - - Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/) -* Support for Blackwell SM100 legacy mixed input GEMM kernels. - - Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp). - - Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp). - - Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/). -* Support for Blackwell SM100 cpasync kernel. - - Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp). - - Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp). -* Support Blackwell SM120 mixed input blockscaled grouped GEMM. -* Instantiating more Blackwell kernels in profiler. - - Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations. - - To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. Profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels. - - Details please check [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md). -* Fix some profiler issues: - - Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line. - - Fix some no output and timeout issues. - - Fix Pingpong Blockwise Hopper library generation. -* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110. - - For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs. - - For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid. -* Rename legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to legacy Python interface. - - Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`. - - Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter. - - Added some support for running SM100 kernels via the Python interface. -* CuTe changes: - - Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/). - - Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support. - - Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels. - - Support fp16 accmulator for sm89 fp8 mma. - - Shorten `nullspace` implementation. - - Isolate and comment on `cosize` hacks. - - Important documentation correction: `E<0,1> == 1@0@1`. + - Fix a shared memory allocation bug where it needs to opt in maximum dynamics shared memory explicitly once it exceeds 48KB. + - Fix a dead hang issue caused by early return warp. +* Add Ragged Contiguous Grouped gemm kernel in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). + - This kernel uses a TMA 3D load to load the weights matrix and use the tensormap update method to load activations. +* Optimize group gemm kernels by enabling async TMA desc update. +* Support Blackwell SM100 convolution stream-K kernel. + - Unit tests: [fprop_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [dgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [wgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu). +* Add profiler support for Blackwell SM100 and SM120 blockscaled sparse kernels. * Fix some kernel issues: - - Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers. - - Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel. -* Add following unit tests: - - [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu) - - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu) - - [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu) + - Fix a race check issue of Blackwell SM103 kernels by adding missing elect one for prefetch barrier initialization. + - Allow user to directly specify the number of stages for Hopper sm90 mixed input gemm. + - Remove warnings caused by cuda vector type alignment setting in CUDA 13. + - Remove problematic `cutlass::int8_t` and replace it with `int8_t`. +* Fix some profiler issues: + - Add some missing reference kernels. + - Add calculation of scale factor A and B in function `bytes_with_problem_shape` of block scaled profiler. Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. -**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.** +**See the [CHANGELOG](https://docs.nvidia.com/cutlass/latest/CHANGELOG.html) for details of all past releases and updates.** # Performance @@ -177,7 +152,7 @@ Layouts can also be combined and manipulated via functional composition, on whic CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design and improves code composability and readability. More documentation specific to CuTe can be found in its -[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html). +[dedicated documentation directory](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html). # Compatibility @@ -263,7 +238,7 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs (SM120). As a result, kernels compiled for Blackwell SM100 architecture with arch conditional features (using `sm100a`) are not compatible with RTX 50 series GPUs. -Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) +Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html) for details on which kernels require which target architectures. # Documentation @@ -271,22 +246,22 @@ for details on which kernels require which target architectures. CUTLASS is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). -- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS -- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS -- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA -- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components -- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts -- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts -- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS -- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project -- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code -- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++ -- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays -- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory -- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory -- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application -- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development -- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent +- [Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS +- [Functionality](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS +- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA +- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components +- [GEMM API 3.x](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts +- [GEMM API 2.x](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts +- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS +- [Code Organization](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project +- [Terminology](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html) - describes terms used in the code +- [Programming Guidelines](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++ +- [Fundamental types](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays +- [Layouts](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory +- [Tile Iterators](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory +- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html) - command-line driven profiling application +- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development +- [Dependent kernel launch](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent kernels in the same stream, and how it is used in CUTLASS. # Resources @@ -306,7 +281,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th paths. CUTLASS unit tests, examples, and utilities can be build with CMake. -The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). +The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html). Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed on your system. @@ -351,7 +326,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl and template concepts defined in the CUTLASS project. A detailed explanation of the source code organization may be found in the -[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below. +[CUTLASS documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/code_organization.html), but several main components are summarized below. ## CUTLASS Template Library @@ -425,7 +400,7 @@ tools/ The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate basic usage of Core API components and complete tests of the CUTLASS GEMM computations. -Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). +Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html). # Performance Profiling @@ -641,9 +616,9 @@ reference_device: Passed ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels: - - [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples) - - [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples) -- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) + - [GEMM CMake Examples](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#gemm-cmake-examples) + - [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#convolution-cmake-examples) +- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html) # About diff --git a/customConfigs.cmake b/customConfigs.cmake index ac86cbe1..563316bd 100644 --- a/customConfigs.cmake +++ b/customConfigs.cmake @@ -36,6 +36,7 @@ set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING " find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) + function(cutlass_generate_kernel_filter_and_testlist_files) set(options) diff --git a/examples/70_blackwell_gemm/CMakeLists.txt b/examples/70_blackwell_gemm/CMakeLists.txt index 5bb294b1..6e87b687 100644 --- a/examples/70_blackwell_gemm/CMakeLists.txt +++ b/examples/70_blackwell_gemm/CMakeLists.txt @@ -33,7 +33,7 @@ set(TEST_SWIZZLE_2 --swizzle=2) set(TEST_SWIZZLE_5 --swizzle=5) set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384) -if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") +if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f") cutlass_example_add_executable( 70_blackwell_fp16_gemm 70_blackwell_fp16_gemm.cu diff --git a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt index 87bcd514..bc32fbe6 100644 --- a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt +++ b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # Both filenames are shorter to avoid MAX_PATH issues on Windows. -if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") +if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f") cutlass_example_add_executable( 71_blackwell_gemm_with_collective_builder 71_blackwell_gemm_with_collective_builder.cu diff --git a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt index 45ccdcca..6c13ae4e 100644 --- a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt +++ b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt @@ -28,7 +28,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") +if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f") cutlass_example_add_executable( 72a_blackwell_nvfp4_bf16_gemm 72a_blackwell_nvfp4_bf16_gemm.cu diff --git a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt index 4ac31f62..4572cfcb 100644 --- a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt +++ b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt @@ -28,7 +28,7 @@ -if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") +if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f") cutlass_example_add_executable( 73_blackwell_gemm_preferred_cluster blackwell_gemm_preferred_cluster.cu diff --git a/examples/74_blackwell_gemm_streamk/CMakeLists.txt b/examples/74_blackwell_gemm_streamk/CMakeLists.txt index 4f808e85..549b51ce 100644 --- a/examples/74_blackwell_gemm_streamk/CMakeLists.txt +++ b/examples/74_blackwell_gemm_streamk/CMakeLists.txt @@ -29,7 +29,7 @@ -if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") +if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f") cutlass_example_add_executable( 74_blackwell_gemm_streamk blackwell_gemm_streamk.cu diff --git a/examples/75_blackwell_grouped_gemm/CMakeLists.txt b/examples/75_blackwell_grouped_gemm/CMakeLists.txt index 304a49f8..8669eb35 100644 --- a/examples/75_blackwell_grouped_gemm/CMakeLists.txt +++ b/examples/75_blackwell_grouped_gemm/CMakeLists.txt @@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes -if(CUTLASS_NVCC_ARCHS STREQUAL "100a") +if("100a" IN_LIST CUTLASS_NVCC_ARCHS) cutlass_example_add_executable( 75_blackwell_grouped_gemm 75_blackwell_grouped_gemm.cu diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 65034d3d..9a5326ca 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -110,6 +110,8 @@ set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify) set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify) +set(TEST_MLA_LARGE_SPLIT_KV --verify --split_kv=20 --page=128) + if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a)) foreach(PREC fp8 fp16) @@ -171,6 +173,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla.cu TEST_COMMAND_OPTIONS TEST_MLA_BASIC + TEST_MLA_LARGE_SPLIT_KV TEST_MLA_SEP_REDUCTION TEST_MLA_FUSE_REDUCTION ) @@ -183,6 +186,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla.cu TEST_COMMAND_OPTIONS TEST_MLA_BASIC + TEST_MLA_LARGE_SPLIT_KV TEST_MLA_SEP_REDUCTION TEST_MLA_FUSE_REDUCTION ) diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 1e094bf4..2063a379 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -78,7 +78,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; - + using ClusterShape = Shape<_1, _1, _1>; static const int Alignment = 128 / sizeof_bits_v; @@ -668,7 +668,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { pipeline_c.producer_acquire(pipeline_c_producer_state); - ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + ElementQK acc_scale = (old_row_max == row_max_safe) ? 0.5f : 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); row_sum *= acc_scale; // row_sum = sum(reg_S) float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); @@ -934,13 +934,16 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { } Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size(tTMrO_i); j += 2) { - float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); - float2 out; - cute::mul(out, scale_f32x2, in); - tTMrO_i(j) = out.x; - tTMrO_i(j+1) = out.y; + + if (scale != 1.0f) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } } copy_out(i); @@ -1009,11 +1012,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); // e^(scale * (old_max - new_max) - float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + float scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); - correction_rescale(scale, uint32_t(TmemAllocation::O0)); + bool warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f); + if (warp_do_correction) { + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + } pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; @@ -1027,11 +1033,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); - scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); - correction_rescale(scale, uint32_t(TmemAllocation::O1)); + warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f); + if (warp_do_correction) { + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + } pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; @@ -1071,7 +1080,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // store to smem Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); - + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); if (epilogue.params.ptr_LSE != nullptr) { @@ -1171,10 +1180,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { auto tOgO = thr_copy.partition_D(sO); auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); clear(tOrO); - + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); #endif - + if (epilogue.params.ptr_LSE != nullptr) { int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index afef0224..453bf77d 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -86,10 +86,10 @@ struct Sm100FmhaGenMainloopWarpspecialized { static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2; static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); - + using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; - + using ClusterShape = Shape<_1, _1, _1>; static const int Alignment = 128 / sizeof_bits_v; @@ -187,7 +187,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { SmemLayoutQ, SmemLayoutK, SmemLayoutV, PipelineQ, PipelineKV, TileShape, Mask >; - + struct Arguments { typename Load::Arguments load; @@ -622,7 +622,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { NumericArrayConverter convert; const int kReleasePipeCount = 10; // must be multiple of 2 - + order_s.wait(); CUTLASS_PRAGMA_UNROLL @@ -646,7 +646,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { } tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); - + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { order_s.arrive(); } @@ -672,7 +672,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_c.producer_acquire(pipeline_c_producer_state); - ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + ElementQK acc_scale = (old_row_max == row_max_safe) ? 0.5f : 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); row_sum *= acc_scale; // row_sum = sum(reg_S) float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); @@ -700,7 +700,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; - + row_sum = local_row_sum; if (final_call) { @@ -781,7 +781,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { template CUTLASS_DEVICE auto correction_epilogue( - float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1, + float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1, GTensor& gO, CTensor const& cO, Shape const& g_shape, Epilogue const& epilogue) { @@ -794,13 +794,13 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32 / sizeof(ElementOut); - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOgO = mma.get_slice(0).partition_C(gO); - + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); @@ -812,7 +812,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); - + Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0); Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); @@ -841,10 +841,10 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor tTMrO0 = make_tensor(shape(tTMEM_LOADcO)); Tensor tTMrO1 = make_tensor(shape(tTMEM_LOADcO)); - + copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0); copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1); - + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO0); j += 2) { float2 in0 = make_float2(tTMrO0(j), tTMrO0(j+1)); @@ -891,24 +891,24 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32; - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; - + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); - + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); @@ -918,7 +918,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { float2 scale_f32x2 = make_float2(scale, scale); Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int(TileShape{}) / kCorrectionTileSize>{})); - + auto copy_in = [&](int i) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); @@ -948,13 +948,16 @@ struct Sm100FmhaGenMainloopWarpspecialized { } Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size(tTMrO_i); j += 2) { - float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); - float2 out; - cute::mul(out, scale_f32x2, in); - tTMrO_i(j) = out.x; - tTMrO_i(j+1) = out.y; + + if (scale != 1.0f) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } } copy_out(i); @@ -981,7 +984,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); @@ -1019,11 +1022,14 @@ struct Sm100FmhaGenMainloopWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); // e^(scale * (old_max - new_max) - float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + float scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); - correction_rescale(scale, uint32_t(TmemAllocation::O0)); + bool warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f); + if (warp_do_correction) { + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + } pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; @@ -1037,11 +1043,14 @@ struct Sm100FmhaGenMainloopWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); - scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); - correction_rescale(scale, uint32_t(TmemAllocation::O1)); + warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f); + if (warp_do_correction) { + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + } pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp index bf41af9f..52e5c0c1 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -77,7 +77,7 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { static constexpr int StageCountK = 1; static constexpr int StageCountV = 1; static constexpr int StageCountKV = StageCountK + StageCountV; - // Support StageCountKV > 2 in the future. + // Support StageCountKV > 2 in the future. static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!"); static_assert(std::is_same_v>, "Only support ThreadShape = Shape<_2, _1, _1>"); @@ -116,24 +116,24 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{}))); - - // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, - // we reuse shared memory for V and O to address this problem, + + // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, + // we reuse shared memory for V and O to address this problem, // and a barrier has been added to coordinate access to shared memory. static constexpr bool IsOrderLoadEpilogue = std::is_same_v; static const int NumWarpsEpilogue = 1; static const int NumWarpsLoad = 1; - + struct TensorStorageQKVO { cute::array_aligned> smem_q; - cute::array_aligned> smem_k; + cute::array_aligned> smem_k; cute::array_aligned> smem_o; // use as O0 cute::array_aligned> smem_v; // use as V0 and O1 }; struct TensorStorageQKV { cute::array_aligned> smem_q; - cute::array_aligned> smem_k; + cute::array_aligned> smem_k; cute::array_aligned> smem_v; }; @@ -689,7 +689,7 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { pipeline_c.producer_acquire(pipeline_c_producer_state); - ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + ElementQK acc_scale = (old_row_max == row_max_safe) ? 0.5f : 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); row_sum *= acc_scale; // row_sum = sum(reg_S) float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); @@ -755,7 +755,7 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { Tensor cS = domain_offset(logical_offset, cS_base); pipeline_c.producer_acquire(pipeline_c_producer_state); - + constexpr bool NeedMask = !std::is_same_v; CUTLASS_PRAGMA_NO_UNROLL @@ -841,13 +841,15 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); #ifndef ONLY_SOFTMAX - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size(tTMrO); j += 2) { - float2 in = make_float2(tTMrO(j), tTMrO(j+1)); - float2 out; - cute::mul(out, scale_f32x2, in); - tTMrO(j) = out.x; - tTMrO(j+1) = out.y; + if (scale != 1.0f) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } } #endif @@ -1017,11 +1019,14 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); // e^(scale * (old_max - new_max) - float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + float scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); - correction_rescale(scale, uint32_t(TmemAllocation::O0)); + bool warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f); + if (warp_do_correction) { + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + } pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); ++pipeline_s1_c_consumer_state; @@ -1035,11 +1040,14 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); - scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); - correction_rescale(scale, uint32_t(TmemAllocation::O1)); + warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f); + if (warp_do_correction) { + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + } pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); ++pipeline_s0_c_consumer_state; @@ -1178,10 +1186,10 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { auto tOgO = thr_copy.partition_D(sO); auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); clear(tOrO); - + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); #endif - + if (epilogue.params.ptr_LSE != nullptr) { int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index e9edb90e..7f6a1bb9 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -1039,10 +1039,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto pipeline_commit_state = pipeline_acquire_state; int pipeline_offset = 0; - for (int i = 0; i < StagesPV; i++) { - cutlass::arch::cp_async_fence(); - } - auto load_stage = [&](auto fn) { pipeline_load.producer_acquire(pipeline_acquire_state); fn(pipeline_acquire_state.index()); @@ -1132,6 +1128,13 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { pipeline_page_table.consumer_release(pipeline_pt_release_state); ++pipeline_pt_release_state; + // Extra async fence if the pipeline_offset can't meet the StagesPV + int extra_offset = 0; + while (pipeline_offset + extra_offset < StagesPV - 1) { + extra_offset++; + cutlass::arch::cp_async_fence(); + } + while (pipeline_offset > 0) { cutlass::arch::cp_async_fence(); @@ -2097,7 +2100,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { cutlass::arch::NamedBarrier( (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue - ).arrive(); + ).arrive_and_wait(); return; } diff --git a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp index 465a9871..3d7efd0f 100644 --- a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -329,6 +329,20 @@ void fmha_bwd_reference_dQ( dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); dim3 block(256); int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type); + cudaError_t result; + if (shared_mem >= (48 << 10)) { + result = cudaFuncSetAttribute( + &fmha_bwd_reference_dQ_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + cudaGetLastError(); // Clear the error state + throw std::runtime_error("Failed to allocate " + + std::to_string(shared_mem >> 10) + " KB dynamic smem for dQ tensor in ref. check - " + + "please try reducing seq_len or skipping ref. check"); + } + } fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); } @@ -356,6 +370,20 @@ void fmha_bwd_reference_dK( dim3 grid(K, H_K * B, 1); dim3 block(std::max(D, 256)); int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type); + cudaError_t result; + if (shared_mem >= (48 << 10)) { + result = cudaFuncSetAttribute( + &fmha_bwd_reference_dK_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + cudaGetLastError(); // Clear the error state + throw std::runtime_error("Failed to allocate " + + std::to_string(shared_mem >> 10) + " KB dynamic smem for dO tensor in ref. check - " + + "please try reducing seq_len or skipping ref. check"); + } + } fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); } @@ -383,6 +411,20 @@ void fmha_bwd_reference_dV( dim3 grid(K, H_K * B, 1); dim3 block(std::max(D_VO, 256)); int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type); + cudaError_t result; + if (shared_mem >= (48 << 10)) { + result = cudaFuncSetAttribute( + &fmha_bwd_reference_dV_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + cudaGetLastError(); // Clear the error state + throw std::runtime_error("Failed to allocate " + + std::to_string(shared_mem >> 10) + " KB dynamic smem for dO tensor in ref. check - " + + "please try reducing seq_len or skipping ref. check"); + } + } fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); } diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp index d674ee95..a93efa30 100644 --- a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -189,6 +189,19 @@ void fmha_reference( dim3 grid(size<0>(mO), size<2>(mO), 1); dim3 block(256); int shared_mem = size<0>(mK) * int(sizeof(typename TensorLSE::value_type)); + cudaError_t result; + if (shared_mem >= (48 << 10)) { + result = cudaFuncSetAttribute( + &fmha_reference_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + cudaGetLastError(); // Clear the error state + throw std::runtime_error("Failed to allocate " + + std::to_string(shared_mem >> 10) + " KB dynamic smem for S/P tensor in ref. check - " + + "please try reducing seq_len or skipping ref. check"); + } + } fmha_reference_kernel<<>>(problem_shape_in, mQ, mK, mV, mO, mLSE, mask); } diff --git a/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp index c83ebdb7..bc057e0b 100644 --- a/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp @@ -185,9 +185,11 @@ void fmha_mla_reference( cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - throw std::runtime_error("couldn't perform smem optin"); - } + cudaGetLastError(); // Clear the error state + throw std::runtime_error("Failed to allocate " + + std::to_string(shared_mem >> 10) + " KB dynamic smem for S/P tensor in ref. check - " + + "please try reducing seq_len or skipping ref. check"); + } } fmha_mla_reference_kernel<<>>( problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale); diff --git a/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu b/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu index 0a23488e..a2b5c877 100644 --- a/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu +++ b/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu @@ -332,6 +332,7 @@ bool verify(MixedDtypeOptions const& options) { // Compute reference output // + // Reference uses dequantized B matrix as input. Hence we need to change alignment. constexpr int AlignmentBdq = 128 / cutlass::sizeof_bits::value; using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< @@ -351,7 +352,7 @@ bool verify(MixedDtypeOptions const& options) { ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - EpilogueSchedule + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_rcgrouped.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_rcgrouped.cu new file mode 100644 index 00000000..a6e757d6 --- /dev/null +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_rcgrouped.cu @@ -0,0 +1,854 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Ragged Contiguous Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates an implementation of Ragged Contiguous Grouped GEMM using a TMA+TMA Blackwell + SM100 TensorOp-based warp-specialized kernel. In Ragged Contiguous Grouped Gemms, weights are of same size for each group, + i.e. matrix A, whereas Activations differ in shape and stride between groups. Therefore, we used a Batched TMA Load + to load Weights matrix into the shared memory for MMA. The Activations are loaded using a PtrArray logic with updates to TMA descriptors. + + To run this example: + + $ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_rcgrouped --m=2048 --n=2048 --k=2048 --groups=10 + + The above example command makes all 10 groups to be sized at the given m, n, k sizes. + Skipping any of the problem dimensions randomizes it across the different groups. + Same applies for alpha and beta values that are randomized across the different groups. + + To run this example for a set of problems using the benchmark option: + + $ ./examples/92_blackwell_grouped_gemm/92_blackwell_moe_gemm_rcgrouped --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using ElementC = cutlass::half_t; // Element type for C and D matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Runtime Cluster Shape +using ClusterShape = Shape; + +// Different configs for 1SM and 2SM MMA kernel +struct MMA1SMConfig { + using MmaTileShape = Shape<_128,_128,_64>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch +}; + +struct MMA2SMConfig { + using MmaTileShape = Shape<_256,_256,_64>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +}; + +template +struct GivenGemmSchedule { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + typename ScheduleConfig::MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, + typename ScheduleConfig::EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + typename ScheduleConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel1SM = GivenGemmSchedule::GemmKernel; +using Gemm1SM = GivenGemmSchedule::Gemm; +using Gemm = Gemm1SM; + +using GemmKernel2SM = GivenGemmSchedule::GemmKernel; +using Gemm2SM = GivenGemmSchedule::Gemm; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +StrideA stride_A; + +// Host-side allocations +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; + +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + bool use_pdl = false; + bool sparse_test = false; + + float alpha = FLT_MAX; + float beta = FLT_MAX; + int iterations = 1000; + int warmup = 1000; + int m = 1024, n = 2048, k = 512, groups = 10; + double sparse_prob = 0.1; + dim3 cluster_shape = dim3(4,2,1); + dim3 cluster_shape_fallback = dim3(2,1,1); + RasterOrderOptions raster_order = RasterOrderOptions::AlongM; + int max_sm_count = INT_MAX; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + if (cmd.check_cmd_line_flag("use_pdl")) { + use_pdl = true; + } + if (cmd.check_cmd_line_flag("sparse_test")) { + sparse_test = true; + cmd.get_cmd_line_argument("sparse_prob", sparse_prob); + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX); + cmd.get_cmd_line_argument("beta", beta, FLT_MAX); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("cluster_m", cluster_shape.x); + cmd.get_cmd_line_argument("cluster_n", cluster_shape.y); + cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x); + cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y); + cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else if(sparse_test){ + std::cout << "Running sparse test" << std::endl; + randomize_sparse_problems(cmd, sparse_prob); + } + else { + randomize_problems(cmd); + } + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + m = cmd_line_m; + k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + + for (int i = groups; i > 0; i--) { + int n = cmd_line_n; + if (n < 0) { + n = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + void randomize_sparse_problems(cutlass::CommandLine &cmd, double prob) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + int num_to_set = int(prob * groups); + + std::vector set_to_zero(groups, false); + std::vector indices(groups, false); + + int index=0; + for(auto it = indices.begin(); it != indices.end(); ++it){ + *it = index++; + } + + // Shuffle indices + std::random_device rd; + std::mt19937 gen(rd()); + std::shuffle(indices.begin(), indices.end(), gen); + + // Set first num_to_set entries to true + for (int i = 0; i < num_to_set; ++i) + set_to_zero[indices[i]] = true; + + problem_sizes_host.reserve(groups); + + m = cmd_line_m; + k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + + for (int i = groups; i > 0; i--) { + int n = cmd_line_n; + if (!set_to_zero[i]){ + if (n < 0) { + n = alignment * ((rand() % 64) + 1); + } + } + else{ + n = 0; + } + + problem_sizes_host.push_back({m, n, k}); + } + } + + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + extent.at(i) = std::atoi(tokens.at(i).c_str()); + } + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "92_blackwell_moe_gemm_rcgrouped\n\n" + << " Blackwell FP8 Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --cluster_m= and --cluster_n= Sets the X,Y dims of the preferred cluster shape\n" + << " --cluster_fallback_m= and --cluster_fallback_n= Sets the X,Y dims of the fallback cluster shape\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --max_sm_count= Run kernels using only these number of SMs\n" + << " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "92_blackwell_moe_gemm_rcgrouped" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + scope_min = static_cast(0); + scope_max = static_cast(2); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.groups)); + auto a_coord = cutlass::make_Coord(options.m * options.groups, options.k); + block_A.reset(a_coord.product()); + +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options &options) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count); + + if (!is_static_v) { + if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 && + (options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) { + std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl; + } + hw_info.cluster_shape = options.cluster_shape; + hw_info.cluster_shape_fallback = options.cluster_shape_fallback; + } + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + // If alpha/beta are provided (via cmd line args) and are scalar, then same alpha/beta applies to all batches. + // If pointers to alpha/beta are provided, then alpha/beta can differ between batches/groups. + if (options.alpha != FLT_MAX){ + // Single alpha for all groups + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.dAlpha = {_0{}, _0{}, 0}; + } + else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + // Only one alpha per each group + fusion_args.dAlpha = {_0{}, _0{}, 1}; + } + if (options.beta != FLT_MAX) { + // Single beta for all groups + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dBeta = {_0{}, _0{}, 0}; + } + else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + // Only one beta per each group + fusion_args.dBeta = {_0{}, _0{}, 1}; + } + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {block_A.get(), stride_A, ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + + return arguments; +} + +bool verify(const Options &options) { + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * M * K, Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {M, N, K}, + ElementAccumulator(alpha_host.at(i)), + ref_A, + ref_B, + ElementAccumulator(beta_host.at(i)), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + for (int iter = 0; iter < options.warmup; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + } + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || + ((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8) + ) + ) { + std::cerr << "This example requires CUDA 12.8 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + allocate(options); + initialize(options); + + // + // Evaluate CUTLASS kernels + // + + std::cout << "Running kernel with 1SM MMA config:" << std::endl; + run(options); + std::cout << "Running kernel with 2SM MMA config:" << std::endl; + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/92_blackwell_moe_gemm/CMakeLists.txt b/examples/92_blackwell_moe_gemm/CMakeLists.txt index c88a461e..b0ece023 100644 --- a/examples/92_blackwell_moe_gemm/CMakeLists.txt +++ b/examples/92_blackwell_moe_gemm/CMakeLists.txt @@ -36,6 +36,9 @@ set(TEST_DEEPSEEK_A_FP4 --m=1024 --n=1 --k=7168 --l=256) # TP=1 shape is too set(TEST_DEEPSEEK_B_FP4 --m=7168 --n=1 --k=512 --l=256) set(TEST_IRREGULAR_MNK_FP4 --m=4080 --n=9 --k=4160 --l=8) +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes +set(TEST_SPARSE_GROUPS --sparse_test --sparse_prob=0.3 --iterations=0) + if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 92_blackwell_moe_gemm_regular @@ -53,6 +56,14 @@ cutlass_example_add_executable( 92_blackwell_moe_gemm_grouped.cu ) +cutlass_example_add_executable( + 92_blackwell_moe_gemm_rcgrouped + 92_blackwell_moe_gemm_rcgrouped.cu + TEST_COMMAND_OPTIONS + TEST_FIXED + TEST_SPARSE_GROUPS +) + cutlass_example_add_executable( 92_blackwell_moe_gemm_fp4_regular 92_blackwell_moe_gemm_fp4_regular.cu diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 6d97ea56..1682cca9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -80,6 +80,7 @@ function(cutlass_example_add_executable NAME) endfunction() + foreach(EXAMPLE 00_basic_gemm 01_cutlass_utilities diff --git a/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py b/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py index 79a23079..6f2504f7 100644 --- a/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py +++ b/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py @@ -89,7 +89,7 @@ def tensor_op_gemm_wrapper( k: cutlass.Int32, l: cutlass.Int32, ): - print(f"\n[DSL INFO] Input Parameters:") + print("\n[DSL INFO] Input Parameters:") print(f"[DSL INFO] mnkl: {(m, n, k, l)}") # Assume alignment of shape to call tensorop_gemm example @@ -111,7 +111,7 @@ def tensor_op_gemm_wrapper( tensor_op_gemm = TensorOpGemm( a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1) ) - print(f"\n[DSL INFO] Created TensorOpGemm instance") + print("\n[DSL INFO] Created TensorOpGemm instance") print(f"[DSL INFO] Input dtype: {a_ptr.value_type}") print(f"[DSL INFO] Output dtype: {c_ptr.value_type}") print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}") @@ -119,11 +119,11 @@ def tensor_op_gemm_wrapper( # No need to compile inside jit function tensor_op_gemm(mA, mB, mC) - print(f"\n[DSL INFO] Executed TensorOpGemm") + print("\n[DSL INFO] Executed TensorOpGemm") def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): - print(f"\nRunning TensorOpGemm test with:") + print("\nRunning TensorOpGemm test with:") print(f"Tensor dimensions: {mnkl}") # (M,K,L) @@ -139,7 +139,7 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda" ).permute(1, 2, 0) - print(f"Input tensor shapes:") + print("Input tensor shapes:") print(f"a: {a.shape}, dtype: {a.dtype}") print(f"b: {b.shape}, dtype: {b.dtype}") print(f"c: {c.shape}, dtype: {c.dtype}\n") @@ -158,7 +158,7 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): ref = torch.einsum("mkl,nkl->mnl", a, b) torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) - print(f"\n[DSL INFO] Results verified successfully!") + print("\n[DSL INFO] Results verified successfully!") print(f"First few elements of result: \n{c[:3, :3, :3]}") diff --git a/examples/python/CuTeDSL/ampere/call_from_jit.py b/examples/python/CuTeDSL/ampere/call_from_jit.py index ee71e53f..68e1b0ae 100644 --- a/examples/python/CuTeDSL/ampere/call_from_jit.py +++ b/examples/python/CuTeDSL/ampere/call_from_jit.py @@ -169,7 +169,7 @@ def tensor_op_gemm_wrapper( acc_dtype: Type[cutlass.Numeric], atom_layout_mnk: cutlass.Constexpr[tuple[int, int, int]], ): - print(f"\n[DSL INFO] Input Parameters:") + print("\n[DSL INFO] Input Parameters:") print(f"[DSL INFO] mnkl: {mnkl}") print(f"[DSL INFO] buffer_a: {buffer_a}") print(f"[DSL INFO] buffer_b: {buffer_b}") @@ -181,7 +181,7 @@ def tensor_op_gemm_wrapper( mB = buffer_b.to_tensor(cute.select(mnkl, mode=[3, 1, 2])) mC = buffer_c.to_tensor(cute.select(mnkl, mode=[3, 0, 1])) - print(f"\n[DSL INFO] Created Tensors:") + print("\n[DSL INFO] Created Tensors:") print(f"[DSL INFO] mA = {mA}") print(f"[DSL INFO] mB = {mB}") print(f"[DSL INFO] mC = {mC}") @@ -192,7 +192,7 @@ def tensor_op_gemm_wrapper( acc_dtype, atom_layout_mnk, ) - print(f"\n[DSL INFO] Created TensorOpGemm instance") + print("\n[DSL INFO] Created TensorOpGemm instance") print(f"[DSL INFO] Input dtype: {buffer_a.ptr.value_type}") print(f"[DSL INFO] Output dtype: {buffer_c.ptr.value_type}") print(f"[DSL INFO] Accumulation dtype: {acc_dtype}") @@ -200,11 +200,11 @@ def tensor_op_gemm_wrapper( # No need to compile inside jit function tensor_op_gemm(mA, mB, mC) - print(f"\n[DSL INFO] Executed TensorOpGemm") + print("\n[DSL INFO] Executed TensorOpGemm") def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): - print(f"\nRunning TensorOpGemm test with:") + print("\nRunning TensorOpGemm test with:") print(f"Tensor dimensions: {mnkl}") ab_dtype = cutlass.Float16 @@ -220,7 +220,7 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): mnkl[3], mnkl[0], mnkl[1], dtype=torch_dtype(c_dtype), device="cuda" ) - print(f"Input tensor shapes:") + print("Input tensor shapes:") print(f"a: {a.shape}, dtype: {a.dtype}") print(f"b: {b.shape}, dtype: {b.dtype}") print(f"c: {c.shape}, dtype: {c.dtype}\n") @@ -251,7 +251,7 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): ref = torch.einsum("lmk,lnk->lmn", a, b) torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) - print(f"\n[DSL INFO] Results verified successfully!") + print("\n[DSL INFO] Results verified successfully!") print(f"First few elements of result: \n{c[:3, :3, :3]}") diff --git a/examples/python/CuTeDSL/ampere/elementwise_add.py b/examples/python/CuTeDSL/ampere/elementwise_add.py index 596822cc..4d41d25b 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_add.py +++ b/examples/python/CuTeDSL/ampere/elementwise_add.py @@ -28,11 +28,10 @@ import argparse +import torch import time from typing import Type -import cuda.bindings.driver as cuda -import torch import cutlass import cutlass.cute as cute @@ -154,7 +153,7 @@ def elementwise_add_kernel( blkCrd = cC[blk_coord] # (TileM, TileN) # Note: these prints only run at compile/jit time - print(f"[DSL INFO] Sliced Tensors per thread block:") + print("[DSL INFO] Sliced Tensors per thread block:") print(f"[DSL INFO] blkA = {blkA.type}") print(f"[DSL INFO] blkB = {blkB.type}") print(f"[DSL INFO] blkC = {blkC.type}") @@ -182,9 +181,9 @@ def elementwise_add_kernel( frgC = cute.make_fragment_like(thrC) thrCrd = thr_copy_C.partition_S(blkCrd) - frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean) + frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean) - print(f"[DSL INFO] Sliced Tensors per thread:") + print("[DSL INFO] Sliced Tensors per thread:") print(f"[DSL INFO] thrA = {thrA.type}") print(f"[DSL INFO] thrB = {thrB.type}") print(f"[DSL INFO] thrC = {thrC.type}") @@ -233,18 +232,18 @@ def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128): val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0)) tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) - print(f"[DSL INFO] Input Tensors:") + print("[DSL INFO] Input Tensors:") print(f"[DSL INFO] mA = {mA.type}") print(f"[DSL INFO] mB = {mB.type}") - print(f"[DSL INFO] Tiling Parameters:") + print("[DSL INFO] Tiling Parameters:") print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block") print(f"[DSL INFO] tv_layout = {tv_layout}") gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN)) gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN)) gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN)) - print(f"[DSL INFO] Tiled Tensors:") + print("[DSL INFO] Tiled Tensors:") print(f"[DSL INFO] gA = {gA.type}") print(f"[DSL INFO] gB = {gB.type}") print(f"[DSL INFO] gC = {gC.type}") @@ -271,7 +270,7 @@ def run_elementwise_add( warmup_iterations=2, iterations=200, ): - print(f"\nRunning Elementwise Add test with:") + print("\nRunning Elementwise Add test with:") print(f"Tensor dimensions: [{M}, {N}]") print(f"Input and Output Data type: {dtype}") @@ -285,7 +284,7 @@ def run_elementwise_add( c = torch.zeros_like(a) - print(f"Input tensor shapes:") + print("Input tensor shapes:") print(f"a: {a.shape}, dtype: {a.dtype}") print(f"b: {b.shape}, dtype: {b.dtype}") print(f"c: {c.shape}, dtype: {c.dtype}\n") @@ -307,7 +306,9 @@ def run_elementwise_add( print("Compiling kernel with cute.compile ...") start_time = time.time() - compiled_func = cute.compile(elementwise_add, a_tensor, b_tensor, c_tensor) + compiled_func = cute.compile( + elementwise_add, a_tensor, b_tensor, c_tensor, options="--generate-line-info" + ) compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") @@ -386,7 +387,7 @@ if __name__ == "__main__": args = parser.parse_args() if not torch.cuda.is_available(): - raise RuntimeError(f"Ampere GPU is required to run this example!") + raise RuntimeError("Ampere GPU is required to run this example!") run_elementwise_add( args.M, diff --git a/examples/python/CuTeDSL/ampere/elementwise_apply.py b/examples/python/CuTeDSL/ampere/elementwise_apply.py index 43c3b5fc..224e8b26 100644 --- a/examples/python/CuTeDSL/ampere/elementwise_apply.py +++ b/examples/python/CuTeDSL/ampere/elementwise_apply.py @@ -30,17 +30,18 @@ import argparse import operator import time -from typing import Type, List +from functools import partial +from typing import List, Type import cuda.bindings.driver as cuda -import torch - -import cutlass import cutlass.cute as cute import cutlass.cute.testing as testing import cutlass.torch as cutlass_torch +import torch from cutlass.cute.runtime import from_dlpack +import cutlass + """ An Elementwise Apply Example using CuTe DSL. @@ -78,103 +79,83 @@ while maintaining high performance through efficient memory access patterns. @cute.kernel def elementwise_apply_kernel( op: cutlass.Constexpr, - inputs: List[cute.Tensor], - gC: cute.Tensor, + mInputs: List[cute.Tensor], + mC: cute.Tensor, cC: cute.Tensor, # coordinate tensor shape: cute.Shape, tv_layout: cute.Layout, # (tid, vid) -> logic coord ): tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() + bidx, bidy, _ = cute.arch.block_idx() + + ############################################################################### + # Slice to local tile of thread block + ############################################################################### + blk_crd = ((None, None), (bidx, bidy)) - # slice for CTAs - cta_coord = ((None, None), bidx) - # logical coord -> address # Leverage the meta-programming capability of the DSL to slice the tensors for each input # All for loops below on input tensors would be fully unrolled automatically at compile time - ctaInputs = [t[cta_coord] for t in inputs] # (TileM, TileN) - ctaC = gC[cta_coord] # (TileM, TileN) - ctaCrd = cC[cta_coord] # (TileM, TileN) + # logical coord -> memory address + gInputs = [t[blk_crd] for t in mInputs] # (TileM, TileN) + gC = mC[blk_crd] # (TileM, TileN) + gCrd = cC[blk_crd] # (TileM, TileN) - print(f"[DSL INFO] Sliced Tensors per thread block:") - for i in cutlass.range_constexpr(len(ctaInputs)): - print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}") - print(f"[DSL INFO] ctaC = {ctaC.type}") - print(f"[DSL INFO] ctaCrd = {ctaCrd.type}") + print("[DSL INFO] Sliced Tensors per thread block:") + for i in cutlass.range_constexpr(len(gInputs)): + print(f"[DSL INFO] ctaInputs{i} = {gInputs[i].type}") + print(f"[DSL INFO] gC = {gC.type}") + print(f"[DSL INFO] gCrd = {gCrd.type}") - # compose with CTA TV layout - # (tid, vid) -> address - tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs] - tidfrgC = cute.composition(ctaC, tv_layout) - tidfrgCrd = cute.composition(ctaCrd, tv_layout) - # print(f"{tv_layout = }") - # print(f"{tidfrgAB[0] = }") + ############################################################################### + # Compose with thread block TV layout to map thread & value indices to memory address + ############################################################################### + # (tid, vid) -> memory address + tidfrgInputs = [cute.composition(t, tv_layout) for t in gInputs] + tidfrgC = cute.composition(gC, tv_layout) + tidfrgCrd = cute.composition(gCrd, tv_layout) - thr_coord = (tidx, (None, None)) + # repeat None like vid to remove hierarchy of layout + thr_crd = (tidx, cute.repeat_like(None, tidfrgInputs[0][1])) - # slice for threads + ############################################################################### + # Slice to local tile of thread + ############################################################################### # vid -> address - thrInputs = [t[thr_coord] for t in tidfrgInputs] # (V) - thrC = tidfrgC[thr_coord] # (V) - thrCrd = tidfrgCrd[thr_coord] + thrInputs = [t[thr_crd] for t in tidfrgInputs] # (V) + thrC = tidfrgC[thr_crd] # (V) + thrCrd = tidfrgCrd[thr_crd] - print(f"[DSL INFO] Sliced Tensors per thread:") + print("[DSL INFO] Sliced Tensors per thread:") for i in cutlass.range_constexpr(len(thrInputs)): print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}") print(f"[DSL INFO] thrC = {thrC.type}") print(f"[DSL INFO] thrCrd = {thrCrd.type}") - # allocate fragments for gmem->rmem - frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs] - frgC = cute.make_fragment_like(thrC, gC.element_type) - frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean) + ############################################################################### + # Compute predicate for out of boundary checks + ############################################################################### + frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean) + print(f"[DSL INFO] frgPred = {frgPred.type}") - for i in cutlass.range(cute.size(frgPred), unroll=1): + for i in cutlass.range_constexpr(cute.size(frgPred)): frgPred[i] = cute.elem_less(thrCrd[i], shape) # if tidx == 0 and bidx == 0: # cute.print_tensor(frgPred) ########################################################## - # Move data to reg address space + # Load data and compute result ########################################################## - # declare the atoms which will be used later for memory copy - # Compile time validation: expect same element type for all input tensors so as to reuse the copy atom for load - assert all(t.element_type == inputs[0].element_type for t in inputs) - - copy_atom_load = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - inputs[0].element_type, - num_bits_per_copy=inputs[0].element_type.width, - ) - copy_atom_store = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - gC.element_type, - num_bits_per_copy=gC.element_type.width, - ) - - for thrInput, frgInput in zip(thrInputs, frgInputs): - cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred) - # Load data before use. The compiler will optimize the copy and load # operations to convert some memory ld/st into register uses. - result = op(*[frgInput.load() for frgInput in frgInputs]) - - # Save the results back to registers. Here we reuse b's registers. - frgC.store(result) - - # Copy the results back to c - cute.copy(copy_atom_store, frgC, thrC, pred=frgPred) + result = op(*[thrInput.load() for thrInput in thrInputs]) + thrC.store(result) @cute.jit def elementwise_apply( - op: cutlass.Constexpr, - a: cute.Tensor, - b: cute.Tensor, - result: cute.Tensor, - stream: cuda.CUstream, + op: cutlass.Constexpr, inputs, result: cute.Tensor, stream: cuda.CUstream ): """CUDA kernel applying binary operator on each element of two n-D input tensors in CuTe Python and store to result tensor. @@ -232,51 +213,71 @@ def elementwise_apply( # Opt-3: SOL with 2D thread tile # * mA layout: (4096, 4096):(4096, 1) - # * TV layout map to (16, 128) logical tile + # * TV layout map to (64, 256) logical tile # * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store - thr_layout = cute.make_layout((4, 32), stride=(32, 1)) - val_layout = cute.make_layout((4, 4), stride=(4, 1)) + + # Use 128bit(16B) load as canonicalized form of val_layout then recast to target element-type + coalesced_ldst_bytes = 16 + + # Compile time validation: expect same element type for all input tensors + assert all(t.element_type == inputs[0].element_type for t in inputs) + dtype = inputs[0].element_type + + thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0)) + val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0)) + val_layout = cute.recast_layout(dtype.width, 8, val_layout) tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) - print(f"[DSL INFO] Input Tensors:") - print(f"[DSL INFO] a = {a.type}") - print(f"[DSL INFO] b = {b.type}") - print(f"[DSL INFO] result = {result.type}") + print("[DSL INFO] Input Tensors:") + for i, t in enumerate(inputs): + print(f"[DSL INFO] inputs{i} = {t}") + print(f"[DSL INFO] result = {result}") - print(f"[DSL INFO] Tiling Parameters:") + print("[DSL INFO] Tiling Parameters:") print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block") print(f"[DSL INFO] tv_layout = {tv_layout}") - gA = cute.zipped_divide(a, tiler_mn) # ((TileM, TileN), (RestM, RestN)) - gB = cute.zipped_divide(b, tiler_mn) # ((TileM, TileN), (RestM, RestN)) - gC = cute.zipped_divide(result, tiler_mn) # ((TileM, TileN), (RestM, RestN)) + print("[DSL INFO] Tiled Tensors:") + mInputs = [cute.zipped_divide(input, tiler_mn) for input in inputs] + # ((TileM, TileN), (RestM, RestN)) + mC = cute.zipped_divide(result, tiler_mn) - print(f"[DSL INFO] Tiled Tensors:") - print(f"[DSL INFO] gA = {gA.type}") - print(f"[DSL INFO] gB = {gB.type}") - print(f"[DSL INFO] gC = {gC.type}") + # (RestM, RestN) -> (RestN, RestM) + remap_block = cute.make_ordered_layout( + cute.select(mInputs[0].shape[1], mode=[1, 0]), order=(1, 0) + ) + for i, t in enumerate(mInputs): + print(f"[DSL INFO] gInputs{i} = {mInputs[i]}") + mInputs[i] = cute.composition(t, (None, remap_block)) + print(f"[DSL INFO] gInputs{i} (remapped) = {mInputs[i]}") + + mC = cute.composition(mC, (None, remap_block)) + print(f"[DSL INFO] gC = {mC}") idC = cute.make_identity_tensor(result.shape) cC = cute.zipped_divide(idC, tiler=tiler_mn) - print(f"[DSL INFO] coord tensor = {cC.type}") + print(f"[DSL INFO] coord tensor = {cC}") # Launch the kernel asynchronously - # Async token(s) can also be specified as dependencies - elementwise_apply_kernel( - op, - [gA, gB], # Group input tensors into a list as a single argument - gC, - cC, - result.shape, - tv_layout, - ).launch( - grid=[cute.size(gC, mode=[1]), 1, 1], + # Group input tensors into a list as a single argument + elementwise_apply_kernel(op, mInputs, mC, cC, result.shape, tv_layout).launch( + # Compute production at each mode of mC.shape[1] to get multi-dimensional grid size + grid=cute.product_each(mC.shape[1]), block=[cute.size(tv_layout, mode=[0]), 1, 1], stream=stream, ) -def run_elementwise_apply_and_verify( +@cutlass.dsl_user_op +def leaky_relu(x, alpha, *, loc=None, ip=None): + return cute.where(x > 0, x, alpha * x, loc=loc, ip=ip) + + +def leaky_relu_ref(x, alpha): + return torch.where(x > 0, x, alpha * x) + + +def run_and_verify( op, M, N, @@ -287,14 +288,23 @@ def run_elementwise_apply_and_verify( iterations=100, ): if not torch.cuda.is_available(): - raise RuntimeError(f"Ampere GPU is required to run this example!") + raise RuntimeError("NVIDIA GPU is required to run this example!") + + if op == "leaky_relu": + op = partial(leaky_relu, alpha=0.01) + ref_op = partial(leaky_relu_ref, alpha=0.01) + num_inputs = 1 + else: + op = getattr(operator, op) + ref_op = op + num_inputs = 2 # Create non default CUDA stream from PyTorch torch_stream = torch.cuda.Stream() # Get the raw stream pointer as a CUstream current_stream = cuda.CUstream(torch_stream.cuda_stream) - print(f"\nRunning Elementwise Apply test with:") + print("\nRunning Elementwise Apply test with:") print(f"Tensor dimensions: [{M}, {N}]") print(f"Input and Output Data type: {dtype}") print(f"Warmup iterations: {warmup_iterations}") @@ -303,85 +313,78 @@ def run_elementwise_apply_and_verify( torch_dtype = cutlass_torch.dtype(dtype) # Allocate tensors with random values. - a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) - b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) - c = torch.zeros_like(a) + inputs = [ + torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + for _ in range(num_inputs) + ] + c = torch.zeros_like(inputs[0]) - print(f"Input tensor shapes:") - print(f"a: {a.shape}, dtype: {a.dtype}") - print(f"b: {b.shape}, dtype: {b.dtype}") + print("Input tensor shapes:") + for i in range(num_inputs): + print(f"inputs[{i}]: {inputs[i].shape}, dtype: {inputs[i].dtype}") print(f"c: {c.shape}, dtype: {c.dtype}\n") epsilon = 1.2 if op in (operator.truediv, operator.floordiv): - b = torch.where(b == 0, torch.tensor(epsilon), b) + inputs[1] = torch.where(inputs[1] == 0, torch.tensor(epsilon), inputs[1]) - print("Executing elementwise apply kernel...") + inputs_ = [from_dlpack(t, assumed_align=16) for t in inputs] + c_ = from_dlpack(c, assumed_align=16).mark_layout_dynamic() + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + compiled_fn = cute.compile[cute.GenerateLineInfo(True)]( + elementwise_apply, op, inputs_, c_, current_stream + ) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") if not skip_ref_check: - elementwise_apply( - op, - from_dlpack(a), - from_dlpack(b), - from_dlpack(c).mark_layout_dynamic(), - current_stream, - ) + print("Executing elementwise apply kernel...") + compiled_fn(inputs_, c_, current_stream) print("Verifying results...") - torch.testing.assert_close(op(a, b), c) + torch.testing.assert_close(ref_op(*inputs), c) print("Results verified successfully!") + print(f"First few elements of result: \n{c[:3, :3]}") if not benchmark: return - compiled_func = cute.compile( - elementwise_apply, - op, - from_dlpack(a), - from_dlpack(b), - from_dlpack(c).mark_layout_dynamic(), - current_stream, - ) - # When compiled we inlined op in the kernel, so we do not pass it when benchmarking + print("Benchmarking elementwise apply kernel...") avg_time_us = testing.benchmark( - compiled_func, - kernel_arguments=testing.JitArguments( - from_dlpack(a), - from_dlpack(b), - from_dlpack(c).mark_layout_dynamic(), - current_stream, - ), + compiled_fn, + kernel_arguments=testing.JitArguments(inputs_, c_, current_stream), warmup_iterations=warmup_iterations, iterations=iterations, use_cuda_graphs=True, stream=current_stream, ) - avg_time = avg_time_us / 1e3 + num_elements = sum(input.numel() for input in inputs) + c.numel() # Print execution results - print(f"Kernel execution time: {avg_time:.4f} ms") + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") print( - f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s" + f"Achieved memory throughput: {(num_elements * dtype.width // 8) / (avg_time_us * 1000):.2f} GB/s" ) - print(f"First few elements of result: \n{c[:3, :3]}") if __name__ == "__main__": parser = argparse.ArgumentParser( - description="example of elementwise apply to demonstrate building elementwise kernels" + description="Demonstration of building customizable elementwise CUDA kernels using the CuTe DSL" ) - parser.add_argument("--M", default=128, type=int) - parser.add_argument("--N", default=128, type=int) + parser.add_argument("--M", default=4096, type=int) + parser.add_argument("--N", default=4096, type=int) parser.add_argument("--op", default="add", type=str) parser.add_argument("--warmup_iterations", default=2, type=int) parser.add_argument("--iterations", default=100, type=int) parser.add_argument("--skip_ref_check", action="store_true") parser.add_argument("--benchmark", action="store_true") args = parser.parse_args() - run_elementwise_apply_and_verify( - getattr(operator, args.op), + run_and_verify( + args.op, args.M, args.N, dtype=cutlass.Float32, diff --git a/examples/python/CuTeDSL/ampere/flash_attention_v2.py b/examples/python/CuTeDSL/ampere/flash_attention_v2.py index c6f8ff4c..aea70d5c 100644 --- a/examples/python/CuTeDSL/ampere/flash_attention_v2.py +++ b/examples/python/CuTeDSL/ampere/flash_attention_v2.py @@ -28,7 +28,7 @@ import argparse from types import SimpleNamespace -from typing import Type, Union, Callable +from typing import Type, Callable import torch import cuda.bindings.driver as cuda @@ -38,6 +38,7 @@ import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warp import cutlass.torch as cutlass_torch from cutlass.cute.runtime import from_dlpack +import cutlass.pipeline as pipeline import cutlass.utils as utils """ @@ -126,6 +127,10 @@ class FlashAttentionForwardAmpere: self._num_threads = num_threads self._is_causal = is_causal + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, num_threads=num_threads + ) + @staticmethod def can_implement( dtype, head_dim, m_block_size, n_block_size, num_threads, is_causal @@ -450,7 +455,7 @@ class FlashAttentionForwardAmpere: acc_shape_O = thr_mma.partition_shape_C( (self._m_block_size, self._head_dim_padded) ) - acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) + acc_O = cute.make_rmem_tensor(acc_shape_O, cutlass.Float32) acc_O.fill(0.0) # /////////////////////////////////////////////////////////////////////////////// @@ -506,7 +511,7 @@ class FlashAttentionForwardAmpere: tKVcKV = gmem_thr_copy_QKV.partition_S(cKV) # Allocate predicate tensors for m and n, here we only allocate the tile of k, and do special process for mn. # This is to reduce register pressure and gets 2-3% performance gain compared with allocating the whole tile. - tQpQ = cute.make_fragment( + tQpQ = cute.make_rmem_tensor( cute.make_layout( ( tQsQ.shape[0][1], @@ -517,7 +522,7 @@ class FlashAttentionForwardAmpere: ), cutlass.Boolean, ) - tKVpKV = cute.make_fragment( + tKVpKV = cute.make_rmem_tensor( cute.make_layout( ( tKsK.shape[0][1], @@ -571,11 +576,11 @@ class FlashAttentionForwardAmpere: # Softmax intermediate result: row_max and row_sum # /////////////////////////////////////////////////////////////////////////////// # shape: (atom_v_m * rest_m) - row_max = cute.make_fragment( + row_max = cute.make_rmem_tensor( (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32 ) # shape: (atom_v_m * rest_m) - row_sum = cute.make_fragment( + row_sum = cute.make_rmem_tensor( (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32 ) row_max.fill(-cutlass.Float32.inf) @@ -710,7 +715,7 @@ class FlashAttentionForwardAmpere: tOgO = gmem_thr_copy_O.partition_D(gO) tOrO = cute.make_fragment_like(tOgO, self._dtype) # sync before all smem stores are done. - cute.arch.barrier() + self.cta_sync_barrier.arrive_and_wait() # load acc O from smem to rmem for wider vectorization cute.copy( gmem_tiled_copy_O, @@ -724,7 +729,7 @@ class FlashAttentionForwardAmpere: (m_block, 0), ) tOcO = gmem_thr_copy_O.partition_D(cO) - tOpO = cute.make_fragment( + tOpO = cute.make_rmem_tensor( cute.make_layout( (tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]), stride=(tOgO.shape[2], 0, 1), @@ -778,12 +783,12 @@ class FlashAttentionForwardAmpere: acc_shape_S = mma_params.thr_mma.partition_shape_C( (self._m_block_size, self._n_block_size) ) - acc_S = cute.make_fragment(acc_shape_S, cutlass.Float32) + acc_S = cute.make_rmem_tensor(acc_shape_S, cutlass.Float32) acc_S.fill(0.0) # wait for smem tile QK before mma calculation for S cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + self.cta_sync_barrier.arrive_and_wait() # load smem tile V for O, special process for the first tile to avoid loading nan. # The `if` here is a constexpr, won't be generated in the IR. if is_first_n_block: @@ -847,7 +852,7 @@ class FlashAttentionForwardAmpere: # wait for smem tile V for O cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + self.cta_sync_barrier.arrive_and_wait() if basic_params.n_block > 0: cute.copy( @@ -1170,7 +1175,7 @@ def run( f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}" ) - print(f"Running Ampere SM80 FlashAttentionForward test with:") + print("Running Ampere SM80 FlashAttentionForward test with:") print(f" dtype: {dtype}") print(f" batch_size: {batch_size}") print(f" seqlen_q: {seqlen_q}") @@ -1285,6 +1290,7 @@ def run( return avg_time_us # Return execution time in microseconds + if __name__ == "__main__": parser = argparse.ArgumentParser( description="example of flash attention v2 with CuTe on GPU" diff --git a/examples/python/CuTeDSL/ampere/hstu_attention.py b/examples/python/CuTeDSL/ampere/hstu_attention.py new file mode 100644 index 00000000..1c5aeb5f --- /dev/null +++ b/examples/python/CuTeDSL/ampere/hstu_attention.py @@ -0,0 +1,1149 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Type +import argparse + +import torch +import cuda.bindings.driver as cuda +import cutlass +import cutlass.torch as cutlass_torch +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +from cutlass._mlir.dialects import llvm +import cutlass.pipeline as pipeline +import cutlass.utils as utils + +""" +A HSTU attention forward pass example for NVIDIA Ampere SM80 architecture using Cute DSL, based on the example of flash_attention_v2 for Ampere. + +The example showcases an implementation of HSTU attention(https://arxiv.org/abs/2402.17152) within generative recommender system. It utilize the formula: `mask(silu(q@k+rab))@v`. The implementation includes the following features: +- efficient fast sigmoid implementation +- block rasterization to improve L2 cache hit rate. +- The correct approach to verify the results of the HSTU attention with a Pytorch implementation. + +To run this example: + +.. code-block:: bash + + python examples/ampere/hstu_attention.py --batch_size 4 --seqlen_q 8192 --seqlen_kv 8192 --num_head 4 --head_dim 128 --m_block_size 128 --n_block_size 64 --is_causal --perf_test + +The above example tests the performance of HSTU attention with batch size 4, sequence length 8192, 4 attention heads, and head dimension 128. The m_block_size is 128, and n_block_size is 64. The causal masking is enabled. + +There are some constraints for this example: +* Only Float16 and BFloat16 are supported. +* The contiguous dimension of each tensor must be at least 16 bytes aligned. +* The values of `m_block_size`, `n_block_size`, and `head_dim` must be selected to stay within shared memory capacity limits. +* `m_block_size * 2` must be divisible by `num_threads`, otherwise the kernel will not be able to get the correct result. +* "seqlen_kv should be greater or equal to seqlen_q. +""" + + +class HSTUAttentionForwardAmpere(object): + def __init__( + self, + dtype, + batch_size, + seqlen_q, + seqlen_kv, + num_head, + head_dim, + m_block_size=128, + n_block_size=128, + num_threads=128, + enable_fast_sigmoid=False, + enable_block_rasterization=False, + is_causal=False, + ): + self._dtype = dtype + self._batch_size = batch_size + self._seqlen_q = seqlen_q + self._seqlen_kv = seqlen_kv + self._num_head = num_head + self._head_dim = head_dim + self._m_block_size = m_block_size + self._n_block_size = n_block_size + # padded head_dim to 32 for cta tile. + self._head_dim_padded = (head_dim + 31) // 32 * 32 + self._num_threads = num_threads + self._enable_fast_sigmoid = enable_fast_sigmoid + self._enable_block_rasterization = enable_block_rasterization + self._is_causal = is_causal + assert self._dtype == cutlass.Float16 or self._dtype == cutlass.BFloat16, ( + "Only Float16 or BFloat16 is supported" + ) + assert self._head_dim % 8 == 0, "head dim should be multiply of 8" + assert self._num_threads % 32 == 0, "num_threads should be multiply of 32" + assert self._m_block_size * self._head_dim_padded // self._num_threads >= 8, ( + "Small m_block_size and too many threads" + ) + assert self._n_block_size * self._head_dim_padded // self._num_threads >= 8, ( + "Small n_block_size and too many threads" + ) + assert seqlen_kv >= seqlen_q, "seqlen_kv should be greater or equal to seqlen_q" + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, num_threads=num_threads + ) + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mRAB: cute.Tensor, + stream: cuda.CUstream, + ): + """Configures and launches the HSTU attention kernel. + mQ/mK/mV/mO/mRAB has same data types(supports fp16 and bf16). + mQ has layout: (batch_size, seqlen_q, num_head, head_dim):(seqlen_q * num_head * head_dim, num_head * head_dim, head_dim, 1) + mK/mV/mO has same layout: (batch_size, seqlen_kv, num_head, head_dim):(seqlen_kv * num_head * head_dim, num_head * head_dim, head_dim, 1) + mRAB has layout: (batch_size, num_head, seqlen_q, seqlen_kv):(seqlen_q*seqlen_kv*num_head, seqlen_q*seqlen_kv, seqlen_kv, 1) + + Prepares the shared memory layout, tiled copy atoms, tiled mma and shared memory storage. + Then launches the kernel function with the prepared parameters. + + :param mQ: query tensor + :type mQ: cute.Tensor + :param mK: key tensor + :type mK: cute.Tensor + :param mV: value tensor + :type mV: cute.Tensor + :param mO: output tensor + :type mO: cute.Tensor + :param mRAB: RAB tensor + :type mRAB: cute.Tensor + """ + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V/RAB + # /////////////////////////////////////////////////////////////////////////////// + smem_k_block_size = 64 if self._head_dim_padded % 64 == 0 else 32 + swizzle_bits = 3 if smem_k_block_size == 64 else 2 + sQ_layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 4, 3), + 0, + cute.make_layout((8, smem_k_block_size), stride=(smem_k_block_size, 1)), + ) + sQ_layout = cute.tile_to_shape( + sQ_layout_atom, + (self._m_block_size, self._head_dim_padded), + (0, 1), + ) + sKV_layout_atom = sQ_layout_atom + sKV_layout = cute.tile_to_shape( + sKV_layout_atom, + (self._n_block_size, self._head_dim_padded), + (0, 1), + ) + sRAB_layout_atom = sQ_layout_atom + sRAB_layout = cute.tile_to_shape( + sRAB_layout_atom, (self._m_block_size, self._n_block_size), (0, 1) + ) + sO_layout = sQ_layout + + @cute.struct + class SharedStorage: + sQ: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sQ_layout)], 1024 + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024 + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024 + ] + sRAB: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sRAB_layout)], 1024 + ] + + assert SharedStorage.size_in_bytes() < utils.get_smem_capacity_in_bytes( + "sm_80" + ), "insufficient shared memory" + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self._dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + self._dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self._dtype, + num_bits_per_copy=universal_copy_bits, + ) + # tQKV_layout: thread layout for QKV load + tQKV_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + tQKV_layout = cute.make_layout( + (self._num_threads // tQKV_shape_dim_1, tQKV_shape_dim_1), + stride=(tQKV_shape_dim_1, 1), + ) + # tO_layout: thread layout for O store + tO_layout = tQKV_layout + + # Value layouts for copies + vQKV_layout = cute.make_layout((1, async_copy_elems)) + vO_layout = vQKV_layout + + # gmem_tiled_copy_QKV: tiled copy for QKV load + gmem_tiled_copy_QKV = cute.make_tiled_copy_tv( + atom_async_copy, tQKV_layout, vQKV_layout + ) + + # gmem_tiled_copy_O: tiled copy for O store + gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tO_layout, vO_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled mma + # /////////////////////////////////////////////////////////////////////////////// + tiled_mma = cute.make_tiled_mma( + cute.nvgpu.warp.MmaF16BF16Op(self._dtype, cutlass.Float32, (16, 8, 16)), + (self._num_threads // 32, 1, 1), + permutation_mnk=(self._num_threads // 32 * 16, 16, 16), + ) + + # block rasterization + if cutlass.const_expr(self._enable_block_rasterization): + grid_dim = ( + self._batch_size, + self._num_head, + cute.ceil_div(mQ.shape[1], self._m_block_size), + ) + else: + grid_dim = ( + cute.ceil_div(mQ.shape[1], self._m_block_size), + self._batch_size, + self._num_head, + ) + + self.kernel( + mQ, + mK, + mV, + mO, + mRAB, + sQ_layout, + sKV_layout, + sRAB_layout, + sO_layout, + gmem_tiled_copy_QKV, + gmem_tiled_copy_O, + tiled_mma, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self._num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mRAB: cute.Tensor, + sQ_layout: cute.ComposedLayout, + sKV_layout: cute.ComposedLayout, + sRAB_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + gmem_tiled_copy_QKV: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + ): + """Kernel function for HSTU attention. + + :param mQ: query tensor + :type mQ: cute.Tensor + :param mK: key tensor + :type mK: cute.Tensor + :param mV: value tensor + :type mV: cute.Tensor + :param mO: output tensor + :type mO: cute.Tensor + :param mRAB: RAB(Relative Attention Bias) tensor + :type mRAB: cute.Tensor + :param sQ_layout: shared memory layout for Q + :type sQ_layout: cute.ComposedLayout + :param sKV_layout: shared memory layout for K/V + :type sKV_layout: cute.ComposedLayout + :param sRAB_layout: shared memory layout for RAB + :type sRAB_layout: cute.ComposedLayout + :param sO_layout: shared memory layout for O + :type sO_layout: cute.ComposedLayout + :param gmem_tiled_copy_QKV: tiled copy for QKV load + :type gmem_tiled_copy_QKV: cute.TiledCopy + :param gmem_tiled_copy_O: tiled copy for O store + :type gmem_tiled_copy_O: cute.TiledCopy + :param tiled_mma: tiled mma + :type tiled_mma: cute.TiledMma + :param SharedStorage: shared storage + :type SharedStorage: cutlass.Constexpr + """ + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + + if cutlass.const_expr(self._enable_block_rasterization): + batch_size, num_head, m_block = cute.arch.block_idx() + else: + m_block, batch_size, num_head = cute.arch.block_idx() + + # reverse the m_block index + m_block = cute.ceil_div(mQ.shape[1], self._m_block_size) - m_block - 1 + + if cutlass.const_expr(self._is_causal): + n_block = ( + cute.ceil_div((m_block + 1) * self._m_block_size, self._n_block_size) + - 1 + ) # for causal case, only process the first n_block tiles + else: + n_block = cute.ceil_div(mK.shape[1], self._n_block_size) - 1 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + # (m_block_size, head_dim) + gQ = cute.local_tile( + mQ[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + # (n_block_size, head_dim, n_block) + gK = cute.local_tile( + mK[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (None, 0), + ) + # (n_block_size, head_dim, n_block) + gV = cute.local_tile( + mV[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (None, 0), + ) + # (m_block_size, n_block_size) + gRAB = cute.local_tile( + mRAB[batch_size, num_head, None, None], + (self._m_block_size, self._n_block_size), + (m_block, None), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sKV_layout) + sV = storage.sV.get_tensor(sKV_layout) + sRAB = storage.sRAB.get_tensor(sRAB_layout) + + # Transpose view of V to tensor with layout (head_dim, n_block_size) for tiled mma + sVt = cute.composition( + sV, + cute.make_layout( + (self._head_dim_padded, self._n_block_size), + stride=(self._n_block_size, 1), + ), + ) + + gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K) + tQgQ = gmem_thr_copy_QKV.partition_S(gQ) + tQsQ = gmem_thr_copy_QKV.partition_D(sQ) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tKgK = gmem_thr_copy_QKV.partition_S(gK) + tKsK = gmem_thr_copy_QKV.partition_D(sK) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tVgV = gmem_thr_copy_QKV.partition_S(gV) + tVsV = gmem_thr_copy_QKV.partition_D(sV) + # (CPY_Atom, CPY_M, CPY_N, n_block) + tRABgRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_S(gRAB) + tRabsRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_D(sRAB) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tSrQ = thr_mma.make_fragment_A(thr_mma.partition_A(sQ)) + tSrK = thr_mma.make_fragment_B(thr_mma.partition_B(sK)) + tOrVt = thr_mma.make_fragment_B(thr_mma.partition_B(sVt)) + acc_shape_O = thr_mma.partition_shape_C( + (self._m_block_size, self._head_dim_padded) + ) + acc_O = cute.make_rmem_tensor(acc_shape_O, cutlass.Float32) + acc_O.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_Q = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self._dtype, + ) + smem_copy_atom_K = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self._dtype, + ) + smem_copy_atom_V = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self._dtype, + ) + smem_copy_atom_RAB = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self._dtype, + ) + smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma) + smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma) + smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma) + smem_tiled_copy_RAB = cute.make_tiled_copy_C(smem_copy_atom_RAB, tiled_mma) + + smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx) + smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx) + smem_thr_copy_V = smem_tiled_copy_V.get_slice(tidx) + smem_thr_copy_RAB = smem_tiled_copy_RAB.get_slice(tidx) + + tSsQ = smem_thr_copy_Q.partition_S(sQ) + tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) + tSsK = smem_thr_copy_K.partition_S(sK) + tSrK_copy_view = smem_thr_copy_K.retile(tSrK) + tOsVt = smem_thr_copy_V.partition_S(sVt) + tOrVt_copy_view = smem_thr_copy_V.retile(tOrVt) + tSsRAB = smem_thr_copy_RAB.partition_S(sRAB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for Q, KV and RAB + mcQ = cute.make_identity_tensor(mQ.layout.shape) + mcKV = cute.make_identity_tensor(mK.layout.shape) + mcRAB = cute.make_identity_tensor(mRAB.layout.shape) + + cQ = cute.local_tile( + mcQ[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + cKV = cute.local_tile( + mcKV[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (n_block, 0), + ) + cRAB = cute.local_tile( + mcRAB[batch_size, num_head, None, None], + (self._m_block_size, self._n_block_size), + (m_block, None), + ) + + # Repeat the partitioning with identity layouts + tQcQ = gmem_thr_copy_QKV.partition_S(cQ) + tKVcKV = gmem_thr_copy_QKV.partition_S(cKV) + tRABcRAB = gmem_thr_copy_QKV.partition_S(cRAB) + + tQpQ = cute.make_rmem_tensor( + cute.make_layout( + ( + tQsQ.shape[0][1], + cute.size(tQsQ, mode=[1]), + cute.size(tQsQ, mode=[2]), + ), + stride=(cute.size(tQsQ, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + tKVpKV = cute.make_rmem_tensor( + cute.make_layout( + ( + tKsK.shape[0][1], + cute.size(tKsK, mode=[1]), + cute.size(tKsK, mode=[2]), + ), + stride=(cute.size(tKsK, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + + # Set predicates for head_dim bounds, seqlen_q/k/v bounds is processed at the first tile. + for rest_v in cutlass.range_constexpr(tQpQ.shape[0]): + for rest_k in cutlass.range_constexpr(tQpQ.shape[2]): + tQpQ[rest_v, 0, rest_k] = cute.elem_less( + tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3] + ) + for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]): + for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]): + tKVpKV[rest_v, 0, rest_k] = cute.elem_less( + tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3] + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]): + cute.copy( + gmem_tiled_copy_QKV, + tQgQ[None, m, None], + tQsQ[None, m, None], + pred=tQpQ[None, m, None], + ) + else: + # Clear the smem tiles to account for predicated off loads + tQsQ[None, m, None].fill(0) + + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]): + cute.copy( + gmem_tiled_copy_QKV, + tKgK[None, n, None, n_block], + tKsK[None, n, None], + pred=tKVpKV[None, n, None], + ) + else: + # Clear the smem tiles to account for predicated off loads + tKsK[None, n, None].fill(0) + + for m in cutlass.range_constexpr(cute.size(tRABcRAB.shape[1])): + for n in cutlass.range_constexpr(cute.size(tRABcRAB.shape[2])): + if cute.elem_less( + tRABcRAB[0, m, n, n_block][1], mRAB.layout.shape[2] + ) and cute.elem_less( + tRABcRAB[0, m, n, n_block][2], mRAB.layout.shape[3] + ): + cute.copy( + gmem_tiled_copy_QKV, + tRABgRAB[None, m, n, n_block], + tRabsRAB[None, m, n], + ) + else: + # Clear the smem tiles to account for predicated off loads + tRabsRAB[None, m, n].fill(0) + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + for n_block_idx in range(n_block, -1, -1): + # wait for smem tile QK before mma caculation for S + cute.arch.cp_async_wait_group(0) + self.cta_sync_barrier.arrive_and_wait() + + if n_block_idx == n_block: + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): + if cute.elem_less(tKVcKV[0, n, 0][1], mV.layout.shape[1]): + cute.copy( + gmem_tiled_copy_QKV, + tVgV[None, n, None, n_block_idx], + tVsV[None, n, None], + pred=tKVpKV[None, n, None], + ) + else: + tVsV[None, n, None].fill(0) + else: + cute.copy( + gmem_tiled_copy_QKV, + tVgV[None, None, None, n_block_idx], + tVsV[None, None, None], + pred=tKVpKV[None, None, None], + ) + cute.arch.cp_async_commit_group() + + acc_shape_S = thr_mma.partition_shape_C( + (self._m_block_size, self._n_block_size) + ) + acc_S = cute.make_rmem_tensor(acc_shape_S, cutlass.Float32) + + rRAB_shape_S = thr_mma.partition_shape_C( + (self._m_block_size, self._n_block_size) + ) + rRAB = cute.make_rmem_tensor(rRAB_shape_S, self._dtype) + tSrRAB_copy_view = smem_thr_copy_RAB.retile(rRAB) + cute.copy( + smem_tiled_copy_RAB, + tSsRAB[None, None, None], + tSrRAB_copy_view[None, None, None], + ) + acc_S.store(rRAB.load().to(cutlass.Float32)) + + # /////////////////////////////////////////////////////////////////////////////// + # S gemm calculation + # /////////////////////////////////////////////////////////////////////////////// + # ldmatrix first QK k-block for mma + cute.copy( + smem_tiled_copy_Q, + tSsQ[None, None, 0], + tSrQ_copy_view[None, None, 0], + ) + cute.copy( + smem_tiled_copy_K, + tSsK[None, None, 0], + tSrK_copy_view[None, None, 0], + ) + for k in cutlass.range_constexpr(0, cute.size(tSsQ.shape[2])): + # ldmatrix next QK k-block for mma + if k < cute.size(tSsQ.shape[2]) - 1: + cute.copy( + smem_tiled_copy_Q, + tSsQ[None, None, k + 1], + tSrQ_copy_view[None, None, k + 1], + ) + cute.copy( + smem_tiled_copy_K, + tSsK[None, None, k + 1], + tSrK_copy_view[None, None, k + 1], + ) + # mma for S=Q@K + cute.gemm( + tiled_mma, + acc_S, + tSrQ[None, None, k], + tSrK[None, None, k], + acc_S, + ) + + # wait for smem tile V for O + cute.arch.cp_async_wait_group(0) + self.cta_sync_barrier.arrive_and_wait() + + if n_block_idx > 0: + cute.copy( + gmem_tiled_copy_QKV, + tKgK[None, None, None, n_block_idx - 1], + tKsK[None, None, None], + pred=tKVpKV[None, None, None], + ) + # m residue handling for RAB + for m in cutlass.range_constexpr(cute.size(tRABcRAB.shape[1])): + if cute.elem_less( + tRABcRAB[0, m, 0, n_block][1], mRAB.layout.shape[2] + ): + cute.copy( + gmem_tiled_copy_QKV, + tRABgRAB[None, m, None, n_block_idx - 1], + tRabsRAB[None, m, None], + ) + else: + tRabsRAB[None, m, None].fill(0) + + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # silu activation + # /////////////////////////////////////////////////////////////////////////////// + if self._enable_fast_sigmoid: + t1 = acc_S.load() + t2 = t1 * 0.5 + acc_S.store(t2) + for i in cutlass.range_constexpr(cute.size(acc_S.shape[0])): + for j in cutlass.range_constexpr(cute.size(acc_S.shape[1])): + for k in cutlass.range_constexpr(cute.size(acc_S.shape[2])): + ret = llvm.inline_asm( + cutlass.Float32.mlir_type, + [acc_S[i, j, k].ir_value()], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + acc_S[i, j, k] = ret + t3 = acc_S.load() + t4 = t2 * t3 + t2 + acc_S.store(t4) + else: + LOG2_E = 1.4426950408889634074 + t1 = acc_S.load() + t2 = t1 * -LOG2_E + t3 = cute.math.exp2(t2, fastmath=True) + 1.0 + t4 = t1 / t3 + acc_S.store(t4) + + mACC = cute.make_identity_tensor( + (mRAB.layout.shape[2], mRAB.layout.shape[3]) + ) # (seqlen_q, seqlen_kv) + cACC = cute.local_tile( + mACC[None, None], + (self._m_block_size, self._n_block_size), + (m_block, n_block_idx), + ) + + if self._is_causal and (n_block - n_block_idx) < cute.ceil_div( + self._m_block_size, self._n_block_size + ): + tACCcACC = thr_mma.partition_C(cACC) + for i in cutlass.range_constexpr(cute.size(tACCcACC.shape[0])): + for j in cutlass.range_constexpr(cute.size(tACCcACC.shape[1])): + for k in cutlass.range_constexpr(cute.size(tACCcACC.shape[2])): + if cute.elem_less( + tACCcACC[i, j, k][0], tACCcACC[i, j, k][1] + ): + acc_S[i, j, k] = 0.0 + + rP = cute.make_rmem_tensor_like(acc_S, self._dtype) + rP.store(acc_S.load().to(self._dtype)) + + # /////////////////////////////////////////////////////////////////////////////// + # O gemm calculation + # /////////////////////////////////////////////////////////////////////////////// + # Convert layout of acc_S to gemm O accept layout. + # Due to the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + rP_layout_divided = cute.logical_divide(rP.layout, (None, None, 2)) + rP_mma_view = cute.make_layout( + ( + (rP_layout_divided.shape[0], rP_layout_divided.shape[2][0]), + rP_layout_divided.shape[1], + rP_layout_divided.shape[2][1], + ), + stride=( + (rP_layout_divided.stride[0], rP_layout_divided.stride[2][0]), + rP_layout_divided.stride[1], + rP_layout_divided.stride[2][1], + ), + ) + tOrP = cute.make_tensor(rP.iterator, rP_mma_view) + + # ldmatrix first V k-block for mma + cute.copy( + smem_tiled_copy_V, + tOsVt[None, None, 0], + tOrVt_copy_view[None, None, 0], + ) + for k in cutlass.range_constexpr(0, cute.size(tOrP.shape[2])): + # ldmatrix next V k-block for mma + if k < cute.size(tOrP.shape[2]) - 1: + cute.copy( + smem_tiled_copy_V, + tOsVt[None, None, k + 1], + tOrVt_copy_view[None, None, k + 1], + ) + # mma for O=P@V + cute.gemm( + tiled_mma, + acc_O, + tOrP[None, None, k], + tOrVt[None, None, k], + acc_O, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # store acc_O + rO = cute.make_rmem_tensor(acc_O.layout, self._dtype) + rO.store(acc_O.load().to(self._dtype)) + # reuse sQ's data iterator + sO_iter = cute.recast_ptr(sQ.iterator, sO_layout.inner) + sO = cute.make_tensor(sO_iter, sO_layout.outer) + smem_copy_atom_O = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self._dtype + ) + smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma) + smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx) + taccOrO = smem_thr_copy_O.retile(rO) + taccOsO = smem_thr_copy_O.partition_D(sO) + # copy acc O from rmem to smem with sts.32(auto vectorization) + cute.copy( + smem_copy_atom_O, + taccOrO, + taccOsO, + ) + gO = cute.local_tile( + mO[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOrO = cute.make_fragment_like(tOgO, self._dtype) + # sync before all sts are done. + self.cta_sync_barrier.arrive_and_wait() + # load acc O from smem to rmem for wider vectorization + cute.copy( + gmem_tiled_copy_O, + tOsO, + tOrO, + ) + # predicate for O + mcO = cute.make_identity_tensor(mO.layout.shape) + cO = cute.local_tile( + mcO[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + tOcO = gmem_thr_copy_O.partition_D(cO) + tOpO = cute.make_rmem_tensor( + cute.make_layout( + (tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]), + stride=(tOgO.shape[2], 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tOpO.shape[0]): + for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])): + tOpO[rest_v, 0, rest_n] = cute.elem_less( + tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3] + ) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])): + if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None], + ) + + +def run_pytorch_hstu_test( + dtype: torch.dtype, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + rab: torch.Tensor, + is_causal: bool, +): + """Generate the reference output of the HSTU attention with Pytorch. + + :param dtype: data type of the input tensors + :type dtype: torch.dtype + :param q: query tensor + :type q: torch.Tensor + :param k: key tensor + :type k: torch.Tensor + :param v: value tensor + :type v: torch.Tensor + :param rab: RAB tensor + :type rab: torch.Tensor + :param is_causal: whether to use causal masking + :type is_causal: bool + """ + q = q.to(dtype) + k = k.to(dtype) + v = v.to(dtype) + rab = rab.to(dtype) + + s_ = torch.matmul(q, k.transpose(-2, -1)) + rab + s_ = torch.nn.functional.silu(s_) + if is_causal: + mask = torch.ones(1, 1, q.shape[2], k.shape[2], dtype=dtype) + mask = torch.tril(mask) + s_ = s_ * mask.cuda() + + o = torch.matmul(s_, v).permute(0, 2, 1, 3).contiguous() + return o + + +def run( + dtype: Type[cutlass.Numeric], + batch_size: int, + seqlen_q: int, + seqlen_kv: int, + num_head: int, + head_dim: int, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 128, + enable_fast_sigmoid: bool = False, + enable_block_rasterization: bool = False, + is_causal: bool = False, + perf_test: bool = False, + **kwargs, +): + """ + Run the HSTU attention kernel. + + :param dtype: data type of the input tensors + :type dtype: Type[cutlass.Numeric] + :param batch_size: batch size + :type batch_size: int + :param seqlen_q: sequence length of the query + :type seqlen_q: int + :param seqlen_kv: sequence length of the key + :type seqlen_kv: int + :param num_head: number of attention heads + :type num_head: int + :param head_dim: dimension of the head + :type head_dim: int + :param m_block_size: block size for the m dimension of computation + :type m_block_size: int + :param n_block_size: block size for the n dimension of computation + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param enable_fast_sigmoid: whether to use fast sigmoid + :type enable_fast_sigmoid: bool + :param enable_block_rasterization: whether to use block rasterization + :type enable_block_rasterization: bool + :param is_causal: whether to use causal masking + :type is_causal: bool + """ + assert dtype == cutlass.Float16 or dtype == cutlass.BFloat16 + + torch_stream = torch.cuda.current_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + + print("Running Ampere SM80 HSTUAttentionForward test with:") + print("batch_size: ", batch_size) + print("seqlen_q: ", seqlen_q) + print("seqlen_kv: ", seqlen_kv) + print("num_head: ", num_head) + print("head_dim: ", head_dim) + print("m_block_size: ", m_block_size) + print("n_block_size: ", n_block_size) + print("num_threads: ", num_threads) + print("is_causal: ", is_causal) + print("enable_fast_sigmoid: ", enable_fast_sigmoid) + print("enable_block_rasterization: ", enable_block_rasterization) + print("dtype: ", dtype) + + # reduced tensor num and iter num for functionality test + TENSOR_NUM = 1 + ITER_NUM = 1 + WARMUP_NUM = 0 + if perf_test: + TENSOR_NUM = 3 + ITER_NUM = 100 + WARMUP_NUM = 10 + + # Create tensor Q/K/V/O + qs = [ + torch.randn( + batch_size, seqlen_q, num_head, head_dim, dtype=cutlass_torch.dtype(dtype) + ).cuda() + for _ in range(TENSOR_NUM) + ] + ks = [ + torch.randn( + batch_size, seqlen_kv, num_head, head_dim, dtype=cutlass_torch.dtype(dtype) + ).cuda() + for _ in range(TENSOR_NUM) + ] + vs = [ + torch.randn( + batch_size, seqlen_kv, num_head, head_dim, dtype=cutlass_torch.dtype(dtype) + ).cuda() + for _ in range(TENSOR_NUM) + ] + os = [ + torch.randn( + batch_size, seqlen_q, num_head, head_dim, dtype=cutlass_torch.dtype(dtype) + ).cuda() + for _ in range(TENSOR_NUM) + ] + + rabs = [ + torch.randn( + batch_size, num_head, seqlen_q, seqlen_kv, dtype=cutlass_torch.dtype(dtype) + ).cuda() + for _ in range(TENSOR_NUM) + ] + + fa2_fwd = HSTUAttentionForwardAmpere( + dtype, + batch_size, + seqlen_q, + seqlen_kv, + num_head, + head_dim, + m_block_size, + n_block_size, + num_threads, + enable_fast_sigmoid=enable_fast_sigmoid, + enable_block_rasterization=enable_block_rasterization, + is_causal=is_causal, + ) + # assume input is 16B align. + mqs = [ + ( + from_dlpack(qs[i], assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=qs[i].dim_order(), + divisibility=(128 // dtype.width), + ) + ) + for i in range(TENSOR_NUM) + ] + mks = [ + ( + from_dlpack(ks[i], assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=ks[i].dim_order(), + divisibility=(128 // dtype.width), + ) + ) + for i in range(TENSOR_NUM) + ] + mvs = [ + ( + from_dlpack(vs[i], assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=vs[i].dim_order(), + divisibility=(128 // dtype.width), + ) + ) + for i in range(TENSOR_NUM) + ] + mos = [ + ( + from_dlpack(os[i], assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=os[i].dim_order(), + divisibility=(128 // dtype.width), + ) + ) + for i in range(TENSOR_NUM) + ] + mrabs = [ + ( + from_dlpack(rabs[i], assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=rabs[i].dim_order(), + divisibility=(128 // dtype.width), + ) + ) + for i in range(TENSOR_NUM) + ] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + kernel = cute.compile( + fa2_fwd, + mqs[0], + mks[0], + mvs[0], + mos[0], + mrabs[0], + stream, + ) + + for i in range(0, ITER_NUM): + if i == WARMUP_NUM: + start_event.record(torch_stream) + # Run the kernel + kernel( + mqs[i % TENSOR_NUM], + mks[i % TENSOR_NUM], + mvs[i % TENSOR_NUM], + mos[i % TENSOR_NUM], + mrabs[i % TENSOR_NUM], + stream, + ) + + end_event.record(torch_stream) + torch.cuda.synchronize(torch_stream) + + elapsed_time = start_event.elapsed_time(end_event) + elapsed_time_avg = elapsed_time / (ITER_NUM - WARMUP_NUM) + + LAST_USED_TENSOR = (ITER_NUM - 1) % TENSOR_NUM + q = qs[LAST_USED_TENSOR].permute(0, 2, 1, 3).contiguous() + k = ks[LAST_USED_TENSOR].permute(0, 2, 1, 3).contiguous() + v = vs[LAST_USED_TENSOR].permute(0, 2, 1, 3).contiguous() + rab = rabs[LAST_USED_TENSOR] + + kernel_out = os[LAST_USED_TENSOR].cpu() + + with torch.cuda.stream(torch_stream): + ref_bf16 = run_pytorch_hstu_test(torch.bfloat16, q, k, v, rab, is_causal).cpu() + ref_fp32 = run_pytorch_hstu_test(torch.float32, q, k, v, rab, is_causal).cpu() + torch.cuda.synchronize(torch_stream) + + assert (kernel_out - ref_fp32).abs().max().item() <= 4 * ( + ref_bf16 - ref_fp32 + ).abs().max().item() + print("Results verified successfully!") + + if perf_test: + print(f"Elapsed time: {elapsed_time_avg:.3f} ms") + return elapsed_time_avg * 1000 # return in microseconds + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="example of HSTU attention with CuTe") + parser.add_argument("--dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--seqlen_q", type=int, default=2048) + parser.add_argument("--seqlen_kv", type=int, default=2048) + parser.add_argument("--num_head", type=int, default=4) + parser.add_argument("--head_dim", type=int, default=128) + parser.add_argument("--m_block_size", type=int, default=64) + parser.add_argument("--n_block_size", type=int, default=64) + parser.add_argument("--num_threads", type=int, default=128) + parser.add_argument( + "--no_fast_sigmoid", action="store_false", dest="enable_fast_sigmoid" + ) + parser.add_argument( + "--no_block_rasterization", + action="store_false", + dest="enable_block_rasterization", + ) + parser.add_argument("--is_causal", action="store_true", dest="is_causal") + parser.add_argument("--perf_test", action="store_true", dest="perf_test") + args = parser.parse_args() + + run( + args.dtype, + args.batch_size, + args.seqlen_q, + args.seqlen_kv, + args.num_head, + args.head_dim, + args.m_block_size, + args.n_block_size, + args.num_threads, + args.enable_fast_sigmoid, + args.enable_block_rasterization, + args.is_causal, + args.perf_test, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/ampere/inline_ptx.py b/examples/python/CuTeDSL/ampere/inline_ptx.py new file mode 100644 index 00000000..b678c871 --- /dev/null +++ b/examples/python/CuTeDSL/ampere/inline_ptx.py @@ -0,0 +1,245 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from functools import partial +from typing import Union + +import torch + +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +from cutlass._mlir.dialects import llvm +from cutlass.cute.typing import Boolean, Int32, Int, Constexpr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass.cute.arch.nvvm_wrappers import FULL_MASK, WARP_SIZE + +""" +A simple example to show how to wrap PTX instructions by using inline_asm op in llvm dialect. + +Situations like: + +1. Instructions that are not already exposed by CuTe DSL via `nvvm` module +2. Sequences of instructions that the compiler otherwise does not generate optimally + +motivate developers to inline PTX themselves. + +In this example, we inline the vote.sync.ballot.b32, vote.sync.any.pred, vote.sync.all.pred, +vote.sync.uni.pred, and use the corresponding ops in nvvm_wrappers.py for the test. + +You can refer to the documentation of `inline_asm op in llvm dialect `_ +and `vote.sync `_ +for more details. + +To run this example: + +.. code-block:: bash + + python examples/ampere/inline_ptx.py + +The example will run the vote kernel with inline PTX and nvvm dialect separately. +The results from inline PTX and nvvm dialect will be verified correspondingly. + +""" + + +@dsl_user_op +def ptx_vote_sync_op( + pred: Boolean, kind: str, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Union[Int32, Boolean]: + return_type = Boolean + return_type_str = "pred" + return return_type( + llvm.inline_asm( + T.bool(), + [ + Boolean(pred).ir_value(loc=loc, ip=ip), + Int32(mask).ir_value(loc=loc, ip=ip), + ], + f"""{{\n\t + .reg .pred ps;\n\t + .reg .pred pd;\n\t + setp.ne.b32 ps, $1, 0;\n\t + vote.sync.{kind}.{return_type_str} pd, ps, $2;\n\t + selp.b32 $0, 1, 0, pd;\n\t + }}""", + "=r,r,i", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +ptx_vote_any_sync = partial(ptx_vote_sync_op, kind="any") +ptx_vote_all_sync = partial(ptx_vote_sync_op, kind="all") +ptx_vote_uni_sync = partial(ptx_vote_sync_op, kind="uni") + + +@dsl_user_op +def ptx_vote_ballot_sync( + pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Union[Int32, Boolean]: + return_type = Int32 + return_type_str = "b32" + return return_type( + llvm.inline_asm( + T.i32(), + [ + Boolean(pred).ir_value(loc=loc, ip=ip), + Int32(mask).ir_value(loc=loc, ip=ip), + ], + f"""{{\n\t + .reg .pred p;\n\t + setp.ne.b32 p, $1, 0;\n\t + vote.sync.ballot.{return_type_str} $0, p, $2;\n\t + }}""", + "=r,r,i", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.kernel +def vote_kernel( + mBallot: cute.Tensor, + mAny: cute.Tensor, + mAll: cute.Tensor, + mUni: cute.Tensor, + use_inline_ptx: Constexpr[bool], +): + tidx, _, _ = cute.arch.thread_idx() + + vote_ballot = ( + ptx_vote_ballot_sync(tidx < 10) + if use_inline_ptx + else cute.arch.vote_ballot_sync(tidx < 10) + ) + vote_any = ( + ptx_vote_any_sync(tidx < 10) + if use_inline_ptx + else cute.arch.vote_any_sync(tidx < 10) + ) + vote_all = ( + ptx_vote_all_sync(tidx < 10) + if use_inline_ptx + else cute.arch.vote_all_sync(tidx < 10) + ) + vote_uni = ( + ptx_vote_uni_sync(tidx < 10) + if use_inline_ptx + else cute.arch.vote_uni_sync(tidx < 10) + ) + + mBallot[tidx] = vote_ballot + mAny[tidx] = vote_any + mAll[tidx] = vote_all + mUni[tidx] = vote_uni + + +@cute.jit +def vote( + mBallot: cute.Tensor, + mAny: cute.Tensor, + mAll: cute.Tensor, + mUni: cute.Tensor, + use_inline_ptx: Constexpr[bool], +): + vote_kernel( + mBallot, + mAny, + mAll, + mUni, + use_inline_ptx, + ).launch( + grid=[1, 1, 1], + block=[cute.size(WARP_SIZE, mode=[0]), 1, 1], + ) + + +def run(): + ballot_ptx = torch.randint( + 0, 100, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.int32 + ) + any_ptx = torch.randint( + 0, 2, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.bool + ) + all_ptx = torch.randint( + 0, 2, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.bool + ) + uni_ptx = torch.randint( + 0, 2, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.bool + ) + + mBallotPTX = from_dlpack(ballot_ptx).mark_layout_dynamic() + mAnyPTX = from_dlpack(any_ptx).mark_layout_dynamic() + mAllPTX = from_dlpack(all_ptx).mark_layout_dynamic() + mUniPTX = from_dlpack(uni_ptx).mark_layout_dynamic() + + # get the results from ptx + vote(mBallotPTX, mAnyPTX, mAllPTX, mUniPTX, use_inline_ptx=True) + + ballot_nvvm = torch.randint( + 0, 100, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.int32 + ) + any_nvvm = torch.randint( + 0, 2, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.bool + ) + all_nvvm = torch.randint( + 0, 2, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.bool + ) + uni_nvvm = torch.randint( + 0, 2, (WARP_SIZE,), device=torch.device("cuda"), dtype=torch.bool + ) + + mBallotNVVM = from_dlpack(ballot_nvvm).mark_layout_dynamic() + mAnyNVVM = from_dlpack(any_nvvm).mark_layout_dynamic() + mAllNVVM = from_dlpack(all_nvvm).mark_layout_dynamic() + mUniNVVM = from_dlpack(uni_nvvm).mark_layout_dynamic() + + # get the results from nvvm + vote(mBallotNVVM, mAnyNVVM, mAllNVVM, mUniNVVM, use_inline_ptx=False) + + print("Verifying ballot results...") + torch.testing.assert_close(ballot_ptx, ballot_nvvm) + print("Verifying any results...") + torch.testing.assert_close(any_ptx, any_nvvm) + print(torch.all(any_ptx == any(i < 10 for i in range(WARP_SIZE)))) + assert torch.all(any_ptx == any(i < 10 for i in range(WARP_SIZE))) + print("Verifying all results...") + torch.testing.assert_close(all_ptx, all_nvvm) + assert torch.all(all_ptx == all(i < 10 for i in range(WARP_SIZE))) + print("Verifying uni results...") + torch.testing.assert_close(uni_ptx, uni_nvvm) + assert torch.all(uni_ptx == (len(set(i < 10 for i in range(WARP_SIZE))) == 1)) + print("Results verified successfully!") + + +if __name__ == "__main__": + run() diff --git a/examples/python/CuTeDSL/ampere/sgemm.py b/examples/python/CuTeDSL/ampere/sgemm.py index e7722a8d..0b6a8c9a 100644 --- a/examples/python/CuTeDSL/ampere/sgemm.py +++ b/examples/python/CuTeDSL/ampere/sgemm.py @@ -36,7 +36,7 @@ import torch import cutlass import cutlass.cute as cute import cutlass.cute.testing as testing -import cutlass.torch as cutlass_torch +import cutlass.pipeline as pipeline import cutlass.utils as utils from cutlass.cute.runtime import from_dlpack @@ -103,6 +103,9 @@ class SGemm: assert self._bM % 16 == 0, "multiple of 16 required for tile dimension M" assert self._bN % 16 == 0, "multiple of 16 required for tile dimension N" assert self._num_stages >= 3, "num_stages must be greater than or equal to 3" + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, num_threads=num_threads + ) @cute.jit def __call__( @@ -166,9 +169,8 @@ class SGemm: mA.element_type, num_bits_per_copy=mB.element_type.width, ) - if cutlass.const_expr(self.a_major_mode == utils.LayoutEnum.COL_MAJOR): - num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1 + num_vectorized = 4 if (mA.layout[0].max_alignment % 16 == 0) else 1 atom_async_copy_A = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), mA.element_type, @@ -182,7 +184,7 @@ class SGemm: vA = cute.make_layout((num_vectorized, 1)) if cutlass.const_expr(self.b_major_mode == utils.LayoutEnum.COL_MAJOR): - num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1 + num_vectorized = 4 if (mB.layout[0].max_alignment % 16 == 0) else 1 atom_async_copy_B = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), mA.element_type, @@ -294,7 +296,7 @@ class SGemm: # tile (instead of the last one) irregular in shape when k is irregular. # We first handle the irregular tile to avoid checking for this # condition within the mainloop. - residue_k = mA.shape[1] - cutlass.Int32(self._bK) * gA.shape[2] + residue_k = mA.shape[1] - self._bK * gA.shape[2] gA = cute.domain_offset((0, residue_k, 0), gA) gB = cute.domain_offset((0, residue_k, 0), gB) @@ -342,7 +344,7 @@ class SGemm: tAcA = thr_copy_A.partition_S(cA) tBcB = thr_copy_B.partition_S(cB) # Allocate predicate tensors for m and n - tApA = cute.make_fragment( + tApA = cute.make_rmem_tensor( cute.make_layout( ( tAsA.shape[0][1], @@ -353,7 +355,7 @@ class SGemm: ), cutlass.Boolean, ) - tBpB = cute.make_fragment( + tBpB = cute.make_rmem_tensor( cute.make_layout( ( tBsB.shape[0][1], @@ -365,7 +367,7 @@ class SGemm: cutlass.Boolean, ) # Allocate predicate tensors for m, n and k for residue k-tile - tApA_residue_k = cute.make_fragment( + tApA_residue_k = cute.make_rmem_tensor( cute.make_layout( ( tAsA.shape[0][1], @@ -380,7 +382,7 @@ class SGemm: ), cutlass.Boolean, ) - tBpB_residue_k = cute.make_fragment( + tBpB_residue_k = cute.make_rmem_tensor( cute.make_layout( ( tBsB.shape[0][1], @@ -508,7 +510,7 @@ class SGemm: if k_block_max > 1: # Wait until our first prefetched tile is loaded in cute.arch.cp_async_wait_group(k_pipe_max - 2) - cute.arch.barrier() + self.cta_sync_barrier.arrive_and_wait() # Prefetch the first rmem from the first k-tile cute.autovec_copy(tCsA_p[None, None, 0], tCrA[None, None, 0]) cute.autovec_copy(tCsB_p[None, None, 0], tCrB[None, None, 0]) @@ -545,7 +547,7 @@ class SGemm: tCsA_p = tCsA[None, None, None, smem_pipe_read] tCsB_p = tCsB[None, None, None, smem_pipe_read] cute.arch.cp_async_wait_group(k_pipe_max - 2) - cute.arch.barrier() + self.cta_sync_barrier.arrive_and_wait() # Load A, B from shared memory to registers for k_block + 1 k_block_next = (k_block + 1) % k_block_max # static @@ -611,13 +613,13 @@ class SGemm: # them without vectorization. # /////////////////////////////////////////////////////////////////////////////// cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + self.cta_sync_barrier.arrive_and_wait() tCrC.store(epilogue_op(tCrC.load())) # predicate cC = cute.make_identity_tensor(gC.shape) tCpC = thr_mma.partition_C(cC) - predC = cute.make_fragment(tCrC.layout, cutlass.Boolean) + predC = cute.make_rmem_tensor(tCrC.layout, cutlass.Boolean) residue_m = mC.shape[0] - cutlass.Int32(self._bM) * bidx residue_n = mC.shape[1] - cutlass.Int32(self._bN) * bidy for i in range(cute.size(tCrC.shape)): @@ -664,7 +666,7 @@ def run( :return: Execution time of the GEMM kernel in microseconds :rtype: float """ - print(f"Running Ampere SIMT GEMM example:") + print("Running Ampere SIMT GEMM example:") print(f"mnk: {mnk}") print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}") print(f"Static shape: {static_shape}") @@ -697,14 +699,17 @@ def run( divisibility_b = b.shape[1] if b_major == "k" else b.shape[0] divisibility_c = c.shape[1] if c_major == "n" else c.shape[0] - a_tensor = ( - from_dlpack(a, assumed_align=16) - .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) - .mark_compact_shape_dynamic( - mode=(1 if a_major == "k" else 0), - divisibility=divisibility_a, + if static_shape: + a_tensor = ( + from_dlpack(a, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if a_major == "k" else 0), + divisibility=divisibility_a, + ) ) - ) + else: + a_tensor = from_dlpack(a, assumed_align=16) b_tensor = ( from_dlpack(b, assumed_align=16) @@ -733,12 +738,8 @@ def run( print("Compiling kernel with cute.compile ...") start_time = time.time() - compiled_fn = cute.compile( - sgemm, - a_tensor, - b_tensor, - c_tensor, - stream=current_stream, + compiled_fn = cute.compile[cute.GenerateLineInfo]( + sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream ) compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") @@ -833,7 +834,7 @@ if __name__ == "__main__": parser.add_argument( "--mnk", type=parse_comma_separated_ints, default=(256, 256, 64) ) - parser.add_argument("--a_major", choices=["k", "m"], default="k") + parser.add_argument("--a_major", choices=["k", "m"], default="m") parser.add_argument("--b_major", choices=["k", "n"], default="k") parser.add_argument("--c_major", choices=["n", "m"], default="n") parser.add_argument("--warmup_iterations", default=2, type=int) diff --git a/examples/python/CuTeDSL/ampere/smem_allocator.py b/examples/python/CuTeDSL/ampere/smem_allocator.py index f9f5c1e0..afd9871b 100644 --- a/examples/python/CuTeDSL/ampere/smem_allocator.py +++ b/examples/python/CuTeDSL/ampere/smem_allocator.py @@ -69,7 +69,7 @@ class complex: class SharedStorage: # struct elements with natural alignment a: cute.struct.MemRange[cutlass.Float32, 32] # array - b: cutlass.Int64 # saclar + b: cutlass.Int64 # scalar c: complex # nested struct # struct elements with strict alignment x: cute.struct.Align[ @@ -90,10 +90,17 @@ def kernel( dst_c: cute.Tensor, ): # Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) can be reserved for developer to utilize - # Note: alignment of inital allocator base ptr is 1024 + # Note: alignment of initial allocator base ptr is 1024 allocator = cutlass.utils.SmemAllocator() # base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory) + # -- Allocate a scalar + int_ptr = allocator.allocate(cutlass.Int32) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_INT = SMEM_ADDR_START + aligned_size(int) + assert int_ptr.dtype == cutlass.Int32, "Expected Int32, but got {}".format( + int_ptr.dtype + ) + # -- Allocate a struct -- # Note: when specified alignment, max(alignment, alignof(struct)) will be applied # reserves the section of struct in smem, elements in the struct can be accessed by ptr @@ -153,7 +160,7 @@ def kernel( @cute.jit -def run_allocation_kernel( +def host( const_a: cutlass.Constexpr, dst_a: cute.Tensor, const_b: cutlass.Constexpr, @@ -161,22 +168,18 @@ def run_allocation_kernel( const_c: cutlass.Constexpr, dst_c: cute.Tensor, ): - # additional size for the example, 64(section) + 112(array) + 128(tensor) < 384 - addtional_bytes = 384 - # Note: launch shared memory size is: SMEM_SIZE = 512 + 384 = 896 bytes + # Note: Shared Memory size is automatically calculated now kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch( - grid=(1, 1, 1), - block=(1, 1, 1), - smem=SharedStorage.size_in_bytes() + addtional_bytes, + grid=(1, 1, 1), block=(1, 1, 1) ) -def veify_allocation_kernel(const_a, const_b, const_c): +def run_and_verify(const_a, const_b, const_c): dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda") dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda") dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda") - run_allocation_kernel( + host( const_a, from_dlpack(dst_a), const_b, @@ -185,9 +188,15 @@ def veify_allocation_kernel(const_a, const_b, const_c): from_dlpack(dst_c), ) - np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0]) - np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0]) - np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0]) + assert const_a == dst_a.cpu()[0, 0], ( + f"Expected {const_a}, but got {dst_a.cpu()[0, 0]}" + ) + assert const_b == dst_b.cpu()[0, 0], ( + f"Expected {const_b}, but got {dst_b.cpu()[0, 0]}" + ) + assert const_c == dst_c.cpu()[0, 0], ( + f"Expected {const_c}, but got {dst_c.cpu()[0, 0]}" + ) if __name__ == "__main__": @@ -197,4 +206,4 @@ if __name__ == "__main__": const_a = 0.5 const_b = 1.0 const_c = 2.0 - veify_allocation_kernel(const_a, const_b, const_c) + run_and_verify(const_a, const_b, const_c) diff --git a/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/examples/python/CuTeDSL/ampere/tensorop_gemm.py index 413d1fbf..d3cc709e 100644 --- a/examples/python/CuTeDSL/ampere/tensorop_gemm.py +++ b/examples/python/CuTeDSL/ampere/tensorop_gemm.py @@ -28,10 +28,8 @@ import argparse import math -import time from typing import Tuple, Type -import cuda.bindings.driver as cuda import torch import cutlass @@ -121,12 +119,12 @@ class TensorOpGemm: self.mma_inst_shape = (16, 8, 16) mmaM, mmaN, mmaK = self.mma_inst_shape - assert ( - self.bM % (atom_lay_M * mmaM) == 0 - ), "bM must be divisible by MMA instruction" - assert ( - self.bN % (atom_lay_N * mmaN) == 0 - ), "bN must be divisible by MMA instruction" + assert self.bM % (atom_lay_M * mmaM) == 0, ( + "bM must be divisible by MMA instruction" + ) + assert self.bN % (atom_lay_N * mmaN) == 0, ( + "bN must be divisible by MMA instruction" + ) assert atom_lay_K == 1, "this example does not support atom layout K > 1" assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction" assert self.num_stages >= 3, "num_stages must be greater than or equal to 3" @@ -428,7 +426,7 @@ class TensorOpGemm: # at the granularity of a copy atom, so the predicate tensor does not # need separate booleans for individual elements within a copy # atom (for example, the elements of tAgA.shape[0][0].) - tApA = cute.make_fragment( + tApA = cute.make_rmem_tensor( cute.make_layout( ( tAgA.shape[0][1], @@ -439,7 +437,7 @@ class TensorOpGemm: ), cutlass.Boolean, ) - tBpB = cute.make_fragment( + tBpB = cute.make_rmem_tensor( cute.make_layout( ( tBsB.shape[0][1], @@ -707,7 +705,7 @@ class TensorOpGemm: cute.autovec_copy(tCsC_epilogue, tCrC_epilogue) # Create predication tensor for m - tCpC = cute.make_fragment( + tCpC = cute.make_rmem_tensor( cute.make_layout( ( tCgC_epilogue.shape[0][1], @@ -851,7 +849,7 @@ def run( use_cold_l2: bool = False, **kwargs, ): - print(f"Running Ampere tensor core GEMM example:") + print("Running Ampere tensor core GEMM example:") print(f"mnkl: {mnkl}") print( f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" @@ -944,6 +942,7 @@ def run( return avg_time_us # Return execution time in microseconds + if __name__ == "__main__": def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py b/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py new file mode 100644 index 00000000..7d76f724 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py @@ -0,0 +1,2927 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Type, Tuple, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils + +import math + + +""" +High-performance persistent blockwise dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") +- Matrix B is NxKxL, L is batch dimension, B can be column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Each block will apply the scale factor A +- Each row will apply the scale factor B +- For each iteration, the kernel will compute C = A * B and then apply the scale factor C *= SFA * SFB + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final + - Type convert Final matrix to output type. + - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +.. code-block:: bash + + python examples/blackwell/blockwise_gemm/blockwise_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 4096,4096,4096,4 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/blockwise_gemm/blockwise_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 4096,4096,4096,4 + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp8 (e4m3fn) + see detailed valid dtype combinations in below BlockwiseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128/256 +* Mma tiler N must be 128, align with the scaleB requirement +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class BlockwiseGemmKernel: + """This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, e5m2) + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float8E4M3FN + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float16/BFloat16 + - Other data types are not supported for accuracy issues + + :note: Constraints: + - MMA tiler M must be 64/128/256 + - MMA tiler N must be 128 + - Cluster shape M must be multiple of 2 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = BlockwiseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell blockwise dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.acc_update_warp_id = (0, 1, 2, 3) + self.epilog_warp_id = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.scale_warp_id = 10 + self.sched_warp_id = 11 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + ) + ) + self.num_regs_uniform_warps = 64 + self.num_regs_sched_warps = 64 + self.num_regs_epilogue_warps = 216 + self.num_regs_acc_update_warps = 216 + + # Set barrier for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 + * len((self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + # TMEM offset for final accumulator + self.tmem_final_offset = 384 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.scale_granularity_m = 1 + self.scale_granularity_n = 128 + self.scale_granularity_k = 128 + self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m + self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n + self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k + + if self.scale_k_per_tile != 1: + raise ValueError("scale_k_per_tile must be 1") + if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]: + raise ValueError("scale_m_per_tile must be cta_tile_m") + if self.scale_n_per_tile != 1: + raise ValueError("scale_n_per_tile must be 1") + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_scale_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sfa_dtype, + self.sfb_dtype, + self.scale_m_per_tile * self.scale_k_per_tile, + self.scale_n_per_tile * self.scale_k_per_tile, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.sfa_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_m_per_tile, + ), + ) + self.sfb_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_n, self.scale_n_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_n_per_tile, + ), + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param sfa: Scale factor tensor A + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type + self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + tensor_sfa = cute.make_tensor( + sfa.iterator, + cute.make_layout( + ( + (self.scale_granularity_m, sfa.shape[0]), + (self.scale_granularity_k, sfa.shape[1]), + sfa.shape[2], + ), + stride=( + (0, sfa.layout.stride[0]), + (0, sfa.layout.stride[1]), + sfa.layout.stride[2], + ), + ), + ) + tensor_sfb = cute.make_tensor( + sfb.iterator, + cute.make_layout( + ( + (self.scale_granularity_n, sfb.shape[0]), + (self.scale_granularity_k, sfb.shape[1]), + sfb.shape[2], + ), + stride=( + (0, sfb.layout.stride[0]), + (0, sfb.layout.stride[1]), + sfb.layout.stride[2], + ), + ), + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + scale_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_scale_stage * 2 + ] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_tile_stage * 2 + ] + epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tensor_sfa, + tensor_sfb, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize mainloop scale_pipeline (barrier) and states + scale_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + scale_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + scale_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.scale_mbar_ptr.data_ptr(), + num_stages=self.num_scale_stage, + producer_group=scale_pipeline_producer_group, + consumer_group=scale_pipeline_consumer_group, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize epilogue pipeline (barrier) and states + epi_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.acc_update_warp_id), + ) + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + epi_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.epi_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + ) + + # Initialize tile info pipeline (barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # coordinate + cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl)) + cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl)) + # (bM, bK, loopM, loopK, loopL) + cSFA = cute.local_tile( + cSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + cSFB = cute.local_tile( + cSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # scale viewed as C tensor + sSFA_view_as_C_layout = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + self.cta_tile_shape_mnk[1], + self.num_scale_stage, + ), + stride=((0, 1), 0, self.scale_m_per_tile), + ) + sSFB_view_as_C_layout = cute.make_layout( + ( + self.cta_tile_shape_mnk[0], + (self.scale_granularity_n, self.scale_n_per_tile), + self.num_scale_stage, + ), + stride=(0, (0, 1), self.scale_n_per_tile), + ) + sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout) + sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition global/shared tensor for TMA load A/B + # + # load scaleA/scaleB + atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=mSFA_mkl.element_type.width, + ) + tiled_copy_sfa = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + tiled_copy_sfb = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx) + thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl) + tAsSFA = thr_copy_sfa.partition_D(sSFA) + tAcSFA = thr_copy_sfa.partition_S(cSFA) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl) + tBsSFB = thr_copy_sfb.partition_D(sSFB) + tBcSFB = thr_copy_sfb.partition_S(cSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + # query next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # acquire tile info pipeline + tile_info_pipeline.producer_acquire(tile_info_producer_state) + + # store the tile info + cur_tile_coord = work_tile.tile_idx + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = cur_tile_coord[2] + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + # commit tile info pipeline + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized Scale load warp + # + if warp_idx == self.scale_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + scale_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # + # Prepare the mask for scaleA/scaleB + # + tApSFA = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tAsSFA, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + tBpSFB = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tBsSFB, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + + # Peek (try_wait) SCALE buffer empty + scale_producer_state.reset_count() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # + # Slice to per mma tile index + # + tAsSFA_pipe = cute.filter_zeros( + tAsSFA[(None, None, None, scale_producer_state.index)] + ) + tBsSFB_pipe = cute.filter_zeros( + tBsSFB[(None, None, None, scale_producer_state.index)] + ) + tAgSFA_k = cute.filter_zeros( + tAgSFA_mkl[ + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + tBgSFB_k = cute.filter_zeros( + tBgSFB_nkl[ + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + + tAcSFA_compact = cute.filter_zeros( + cute.slice_( + tAcSFA, + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + tBcSFB_compact = cute.filter_zeros( + cute.slice_( + tBcSFB, + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])): + tApSFA[((0, 0), i, (0, 0))] = cute.elem_less( + tAcSFA_compact[(i)][0], mSFA_mkl.shape[0] + ) + for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])): + tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less( + tBcSFB_compact[(i)][0], mSFB_nkl.shape[0] + ) + + # Conditionally wait for Scale buffer empty + scale_pipeline.producer_acquire( + scale_producer_state, peek_scale_empty_status + ) + + # load scaleA/scaleB + cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA) + cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB) + + scale_pipeline.producer_commit(scale_producer_state) + + # Peek (try_wait) Scale buffer empty + scale_producer_state.advance() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait Scale buffer empty + # + scale_pipeline.producer_tail(scale_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Mma mainloop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire( + acc_producer_state, peek_acc_empty_status + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized acc update warps + # + if warp_idx <= self.acc_update_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps) + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc_base, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc, + tRT_tAcc_base, + ) = self.acc_update_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCtAcc_final, + tCgC, + sSFA_view_as_C, + sSFB_view_as_C, + epi_tile, + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + scale_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_scale_stage + ) + + epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # initialize the final accumulator + tTR_rAcc_final.fill(0.0) + + tTR_rSFA = cute.make_rmem_tensor( + cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + tTR_rSFB = cute.make_rmem_tensor( + cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + + scale_consumer_state.reset_count() + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + + acc_consumer_state.reset_count() + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for scale buffer full + # + scale_pipeline.consumer_wait( + scale_consumer_state, peek_scale_full_status + ) + + tTR_sSFA_slice = cute.slice_( + tTR_sSFA, + (None, None, None, 0, None, scale_consumer_state.index), + ) + tTR_sSFB_slice = cute.slice_( + tTR_sSFB, + (None, None, None, 0, None, scale_consumer_state.index), + ) + + scale_atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA) + cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB) + + # + # Wait for accumulator buffer full + # + + acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Update accumulator by scale factor in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Update accumulator by scale factor + # + tTR_rAcc_subtile = tTR_rAcc_final[ + (None, None, None, subtile_idx) + ] + tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)] + tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)] + + acc_vec = tTR_rAcc.load() + final_vec = tTR_rAcc_subtile.load() + scale_a = tTR_rSFA_subtile.load() + scale_b = tTR_rSFB_subtile.load() + scale = scale_a * scale_b + final_vec = acc_vec * scale + final_vec + tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) + + # + # Async arrive accumulator buffer empty + # + scale_pipeline.consumer_release(scale_consumer_state) + scale_consumer_state.advance() + + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)] + tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc)) + + # + # Wait for epilogue buffer empty + # + epi_pipeline.producer_acquire(epi_producer_state) + + # copy the accumulator to tensor memory buffer + cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc) + cute.arch.fence_view_async_tmem_store() + + # + # Async arrive epilogue buffer full + # + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]: + cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps) + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + epi_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, 1 + ) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + num_prev_subtiles = cutlass.Int32(0) + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + mma_tile_coord_mnl[2], + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, epi_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + epi_pipeline.consumer_wait(epi_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + c_buffer = num_prev_subtiles % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + epi_pipeline.consumer_release(epi_consumer_state) + epi_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def acc_update_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + tAcc_final: cute.Tensor, + gC_mnl: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epi_tile: cute.Tile, + ) -> Tuple[ + cute.TiledCopy, + cute.TiledCopy, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + ]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination). + Partition the scale factor tensor for related copy operations. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param tAcc_final: The final accumulator tensor to be copied and partitioned + :type tAcc_final: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param sSFA: The scale factor tensor for A + :type sSFA: cute.Tensor + :param sSFB: The scale factor tensor for B + :type sSFB: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + - tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results + - tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r + - tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r + - tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results + - tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t + :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + tmem_load_atom = None + tmem_store_atom = None + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tAcc_final_epi = cute.flat_divide( + tAcc_final[((None, None), 0, 0, None)], epi_tile + ) + + tiled_copy_t2r = tcgen05.make_tmem_copy( + tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)] + ) + tiled_copy_r2t = tcgen05.make_tmem_copy( + tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sSFA_epi = cute.flat_divide(sSFA, epi_tile) + sSFB_epi = cute.flat_divide(sSFB, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi) + tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc_final_ = cute.make_rmem_tensor( + tTR_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + tTR_rAcc_final = cute.group_modes( + tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_) + ) + + tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi) + tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi) + # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL) + tRT_rAcc_final_ = cute.make_rmem_tensor( + tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N)) + tRT_rAcc_final = cute.group_modes( + tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_) + ) + + return ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc_final, + tRT_tAcc_final, + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sfa_dtype: Type[cutlass.Numeric], + sfb_dtype: Type[cutlass.Numeric], + sfa_count: int, + sfb_count: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout of operand C. + :type c_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6 + + # Default C stages + num_c_stage = 2 + + # Default ScaleA/B stages + num_scale_stage = 10 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + # 1024B alignment + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage + sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage + scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024 + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity // occupancy + - (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[ + cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp + ]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {cutlass.Float32}: + is_valid = False + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + if mma_tiler_mn[1] not in (128,): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not BlockwiseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not BlockwiseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not BlockwiseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip unsupported A/B layout + if not (a_major == "k" and b_major == "k"): + can_implement = False + return can_implement + + +def create_tensors( + l, m, n, k, a_major, b_major, cd_major, ab_dtype, c_dtype, scale_dtype +): + torch.manual_seed(1111) + + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, cd_major == "m", c_dtype) + sfa_torch_cpu = cutlass_torch.matrix(l, m, math.ceil(k / 128), True, scale_dtype) + sfb_torch_cpu = cutlass_torch.matrix( + l, math.ceil(n / 128), math.ceil(k / 128), False, scale_dtype + ) + + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + sfa_tensor, _ = cutlass_torch.cute_tensor_like( + sfa_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16 + ) + sfb_tensor, _ = cutlass_torch.cute_tensor_like( + sfb_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + return ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + ) + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + scale_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + print("Running Blackwell Persistent Dense Blockwise GEMM test with:") + print(f"mnkl: {mnkl}") + print( + f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}, Scale dtype: {scale_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + if not BlockwiseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + ) = create_tensors( + l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype, scale_dtype + ) + # Configure gemm kernel + gemm = BlockwiseGemmKernel( + acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + max_active_clusters, + current_stream, + ) + + # Execution + compiled_gemm( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + current_stream, + ) + + torch.cuda.synchronize() + + # Compute reference result + if not skip_ref_check: + # update + def pad_and_multiply(scale, tensor): + cm, ck, _ = scale.shape + m, k, _ = tensor.shape + IsGroupWise = False + IsBlockWise = False + if ck == math.ceil(k / 128): + IsGroupWise = True + if cm == math.ceil(m / 128): + IsBlockWise = True + if not IsBlockWise and not IsGroupWise: + raise ValueError("Only support granularity = 128") + + k_idx = torch.arange(k, device=scale.device) + if IsGroupWise: + k_idx = k_idx // 128 + m_idx = torch.arange(m, device=scale.device) + if IsBlockWise: + m_idx = m_idx // 128 + expanded_scale = scale[m_idx[:, None], k_idx, :] + + result = expanded_scale * tensor + + return result + + updated_a = pad_and_multiply(sfa_torch_cpu, a_torch_cpu) + updated_b = pad_and_multiply(sfb_torch_cpu, b_torch_cpu) + + ref = torch.einsum("mkl,nkl->mnl", updated_a, updated_b).to( + cutlass_torch.dtype(c_dtype) + ) + res = c_torch_gpu.view(cutlass_torch.dtype(c_dtype)) + + torch.testing.assert_close(res.cpu(), ref.cpu(), atol=tolerance, rtol=1e-03) + + def generate_tensors(): + # Reuse existing CPU reference tensors and create new GPU tensors from them + ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + ) = create_tensors( + l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype, scale_dtype + ) + return testing.JitArguments( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + + sfa_torch_cpu.numel() * sfa_torch_cpu.element_size() + + sfb_torch_cpu.numel() * sfb_torch_cpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Dense Persistent GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--scale_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k"], type=str, default="k") + parser.add_argument("--b_major", choices=["k"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", action="store_true", default=False, help="Use cold L2" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.scale_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py b/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py new file mode 100644 index 00000000..47d89f0d --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py @@ -0,0 +1,3056 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Type, Tuple, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +import math +import random + + +""" +High-performance persistent blockwise contiguous grouped dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture +using CUTE DSL. +- Matrix A is MxKx1, A can be row-major("K"), ValidM is composed of valid m in different groups +- Matrix B is NxKxL, B can be column-major("K"), L is grouped dimension +- Matrix C is MxNx1, C can be row-major("N"), ValidM is composed of valid m in different groups +- Each block will apply the scale factor SFA +- Each row will apply the scale factor SFB +- For each iteration, the kernel will compute C = A * B and then apply the scale factor C *= SFA * SFB + +Matrix A/C Memory Layout Diagrams: + + ``` + Group 0 Group 1 Group 2 + -+---------+---------+---------+ + | | | | + K| ValidM0 | ValidM1 | ValidM2 | + | | | | + -+---------+---------+---------+ + |<- ValidM ->| + ``` + Note: the Group(L) dimension will be flatted into M dimension, and the rest Group(L) size is 1. + each ValidM will be aligned to 128. + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final + - Type convert Final matrix to output type. + - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +.. code-block:: bash + + python examples/blackwell/blockwise_gemm/contiguous_grouped_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 256,4096,4096,16 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/blockwise_gemm/contiguous_grouped_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 256,4096,4096,16 + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp8 (e4m3fn) + see detailed valid dtype combinations in below BlockwiseContiguousGroupedGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128/256 +* Mma tiler N must be 128, align with the scaleB requirement +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class BlockwiseContiguousGroupedGemmKernel: + """This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, e5m2) + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float8E4M3FN + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float16/BFloat16 + - Other data types are not supported for accuracy issues + + :note: Constraints: + - MMA tiler M must be 64/128/256 + - MMA tiler N must be 128 + - Cluster shape M must be multiple of 2 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = BlockwiseContiguousGroupedGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell blockwise dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.acc_update_warp_id = (0, 1, 2, 3) + self.epilog_warp_id = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.scale_warp_id = 10 + self.sched_warp_id = 11 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + ) + ) + self.num_regs_uniform_warps = 64 + self.num_regs_sched_warps = 64 + self.num_regs_epilogue_warps = 216 + self.num_regs_acc_update_warps = 216 + + # Set barrier for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 + * len((self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + # TMEM offset for final accumulator + self.tmem_final_offset = 384 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.scale_granularity_m = 1 + self.scale_granularity_n = 128 + self.scale_granularity_k = 128 + self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m + self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n + self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k + + if self.scale_k_per_tile != 1: + raise ValueError("scale_k_per_tile must be 1") + if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]: + raise ValueError("scale_m_per_tile must be cta_tile_m") + if self.scale_n_per_tile != 1: + raise ValueError("scale_n_per_tile must be 1") + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_scale_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sfa_dtype, + self.sfb_dtype, + self.scale_m_per_tile * self.scale_k_per_tile, + self.scale_n_per_tile * self.scale_k_per_tile, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.sfa_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_m_per_tile, + ), + ) + self.sfb_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_n, self.scale_n_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_n_per_tile, + ), + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + gidx_mapping: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param sfa: Scale factor tensor A + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param gidx_mapping: Mapping from m index to group index + :type gidx_mapping: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type + self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + tensor_sfa = cute.make_tensor( + sfa.iterator, + cute.make_layout( + ( + (self.scale_granularity_m, sfa.shape[0]), + (self.scale_granularity_k, sfa.shape[1]), + sfa.shape[2], + ), + stride=( + (0, sfa.layout.stride[0]), + (0, sfa.layout.stride[1]), + sfa.layout.stride[2], + ), + ), + ) + tensor_sfb = cute.make_tensor( + sfb.iterator, + cute.make_layout( + ( + (self.scale_granularity_n, sfb.shape[0]), + (self.scale_granularity_k, sfb.shape[1]), + sfb.shape[2], + ), + stride=( + (0, sfb.layout.stride[0]), + (0, sfb.layout.stride[1]), + sfb.layout.stride[2], + ), + ), + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + scale_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_scale_stage * 2 + ] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_tile_stage * 2 + ] + epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tensor_sfa, + tensor_sfb, + gidx_mapping, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + gidx_mapping: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize mainloop scale_pipeline (barrier) and states + scale_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + scale_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + scale_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.scale_mbar_ptr.data_ptr(), + num_stages=self.num_scale_stage, + producer_group=scale_pipeline_producer_group, + consumer_group=scale_pipeline_consumer_group, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize epilogue pipeline (barrier) and states + epi_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.acc_update_warp_id), + ) + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + epi_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.epi_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + ) + + # Initialize tile info pipeline (barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # coordinate + cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl)) + cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl)) + # (bM, bK, loopM, loopK, loopL) + cSFA = cute.local_tile( + cSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + cSFB = cute.local_tile( + cSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # scale viewed as C tensor + sSFA_view_as_C_layout = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + self.cta_tile_shape_mnk[1], + self.num_scale_stage, + ), + stride=((0, 1), 0, self.scale_m_per_tile), + ) + sSFB_view_as_C_layout = cute.make_layout( + ( + self.cta_tile_shape_mnk[0], + (self.scale_granularity_n, self.scale_n_per_tile), + self.num_scale_stage, + ), + stride=(0, (0, 1), self.scale_n_per_tile), + ) + sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout) + sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition global/shared tensor for TMA load A/B + # + # load scaleA/scaleB + atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=mSFA_mkl.element_type.width, + ) + tiled_copy_sfa = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + tiled_copy_sfb = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx) + thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl) + tAsSFA = thr_copy_sfa.partition_D(sSFA) + tAcSFA = thr_copy_sfa.partition_S(cSFA) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl) + tBsSFB = thr_copy_sfb.partition_D(sSFB) + tBcSFB = thr_copy_sfb.partition_S(cSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + # query next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # acquire tile info pipeline + tile_info_pipeline.producer_acquire(tile_info_producer_state) + + # get the group info + cur_tile_coord = work_tile.tile_idx + gidx = 0 + if work_tile.is_valid_tile: + gidx = gidx_mapping[cur_tile_coord[0] * self.cta_tile_shape_mnk[0]] + + # store the tile info + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = gidx + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + # commit tile info pipeline + tile_info_pipeline.producer_commit(tile_info_producer_state) + + # advance to next tile + tile_info_producer_state.advance() + + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + gidx = gidx_mapping[cur_tile_coord[0] * self.cta_tile_shape_mnk[0]] + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, 0)] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + # + # Specialized Scale load warp + # + if warp_idx == self.scale_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + scale_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + gidx = gidx_mapping[cur_tile_coord[0] * self.cta_tile_shape_mnk[0]] + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # + # Prepare the mask for scaleA/scaleB + # + tApSFA = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tAsSFA, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + tBpSFB = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tBsSFB, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + + # Peek (try_wait) SCALE buffer empty + scale_producer_state.reset_count() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # + # Slice to per mma tile index + # + tAsSFA_pipe = cute.filter_zeros( + tAsSFA[(None, None, None, scale_producer_state.index)] + ) + tBsSFB_pipe = cute.filter_zeros( + tBsSFB[(None, None, None, scale_producer_state.index)] + ) + tAgSFA_k = cute.filter_zeros( + tAgSFA_mkl[ + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + 0, + ) + ] + ) + tBgSFB_k = cute.filter_zeros( + tBgSFB_nkl[ + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + + tAcSFA_compact = cute.filter_zeros( + cute.slice_( + tAcSFA, + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + 0, + ), + ) + ) + tBcSFB_compact = cute.filter_zeros( + cute.slice_( + tBcSFB, + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])): + tApSFA[((0, 0), i, (0, 0))] = cute.elem_less( + tAcSFA_compact[(i)][0], mSFA_mkl.shape[0] + ) + for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])): + tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less( + tBcSFB_compact[(i)][0], mSFB_nkl.shape[0] + ) + + # Conditionally wait for Scale buffer empty + scale_pipeline.producer_acquire( + scale_producer_state, peek_scale_empty_status + ) + + # load scaleA/scaleB + cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA) + cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB) + + scale_pipeline.producer_commit(scale_producer_state) + + # Peek (try_wait) Scale buffer empty + scale_producer_state.advance() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait Scale buffer empty + # + scale_pipeline.producer_tail(scale_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # MMA warp don't care about gidx + gidx = 0 + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Mma mainloop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire( + acc_producer_state, peek_acc_empty_status + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized acc update warps + # + if warp_idx <= self.acc_update_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps) + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc_base, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc, + tRT_tAcc_base, + ) = self.acc_update_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCtAcc_final, + tCgC, + sSFA_view_as_C, + sSFB_view_as_C, + epi_tile, + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + scale_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_scale_stage + ) + + epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # Acc update warp don't care about gidx + gidx = 0 + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # initialize the final accumulator + tTR_rAcc_final.fill(0.0) + + tTR_rSFA = cute.make_rmem_tensor( + cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + tTR_rSFB = cute.make_rmem_tensor( + cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + + scale_consumer_state.reset_count() + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + + acc_consumer_state.reset_count() + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for scale buffer full + # + scale_pipeline.consumer_wait( + scale_consumer_state, peek_scale_full_status + ) + + tTR_sSFA_slice = cute.slice_( + tTR_sSFA, + (None, None, None, 0, None, scale_consumer_state.index), + ) + tTR_sSFB_slice = cute.slice_( + tTR_sSFB, + (None, None, None, 0, None, scale_consumer_state.index), + ) + + scale_atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA) + cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB) + + # + # Wait for accumulator buffer full + # + + acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Update accumulator by scale factor in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Update accumulator by scale factor + # + tTR_rAcc_subtile = tTR_rAcc_final[ + (None, None, None, subtile_idx) + ] + tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)] + tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)] + + acc_vec = tTR_rAcc.load() + final_vec = tTR_rAcc_subtile.load() + scale_a = tTR_rSFA_subtile.load() + scale_b = tTR_rSFB_subtile.load() + scale = scale_a * scale_b + final_vec = acc_vec * scale + final_vec + tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) + + # + # Async arrive accumulator buffer empty + # + scale_pipeline.consumer_release(scale_consumer_state) + scale_consumer_state.advance() + + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)] + tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc)) + + # + # Wait for epilogue buffer empty + # + epi_pipeline.producer_acquire(epi_producer_state) + + # copy the accumulator to tensor memory buffer + cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc) + cute.arch.fence_view_async_tmem_store() + + # + # Async arrive epilogue buffer full + # + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]: + cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps) + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + epi_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, 1 + ) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # Epilogue warp don't care about gidx + gidx = 0 + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + num_prev_subtiles = cutlass.Int32(0) + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + 0, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, epi_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + epi_pipeline.consumer_wait(epi_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + c_buffer = num_prev_subtiles % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + epi_pipeline.consumer_release(epi_consumer_state) + epi_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def acc_update_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + tAcc_final: cute.Tensor, + gC_mnl: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epi_tile: cute.Tile, + ) -> Tuple[ + cute.TiledCopy, + cute.TiledCopy, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + ]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination). + Partition the scale factor tensor for related copy operations. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param tAcc_final: The final accumulator tensor to be copied and partitioned + :type tAcc_final: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param sSFA: The scale factor tensor for A + :type sSFA: cute.Tensor + :param sSFB: The scale factor tensor for B + :type sSFB: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + - tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results + - tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r + - tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r + - tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results + - tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t + :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + tmem_load_atom = None + tmem_store_atom = None + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tAcc_final_epi = cute.flat_divide( + tAcc_final[((None, None), 0, 0, None)], epi_tile + ) + + tiled_copy_t2r = tcgen05.make_tmem_copy( + tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)] + ) + tiled_copy_r2t = tcgen05.make_tmem_copy( + tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sSFA_epi = cute.flat_divide(sSFA, epi_tile) + sSFB_epi = cute.flat_divide(sSFB, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi) + tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc_final_ = cute.make_rmem_tensor( + tTR_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + tTR_rAcc_final = cute.group_modes( + tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_) + ) + + tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi) + tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi) + # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL) + tRT_rAcc_final_ = cute.make_rmem_tensor( + tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N)) + tRT_rAcc_final = cute.group_modes( + tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_) + ) + + return ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc_final, + tRT_tAcc_final, + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sfa_dtype: Type[cutlass.Numeric], + sfb_dtype: Type[cutlass.Numeric], + sfa_count: int, + sfb_count: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout of operand C. + :type c_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6 + + # Default C stages + num_c_stage = 2 + + # Default ScaleA/B stages + num_scale_stage = 10 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + # 1024B alignment + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage + sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage + scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024 + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity // occupancy + - (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[ + cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp + ]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {cutlass.Float32}: + is_valid = False + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + if mma_tiler_mn[1] not in (128,): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + cluster_tiler_m = ( + cluster_shape_mn[0] // (2 if use_2cta_instrs else 1) + ) * mma_tiler_mn[0] + # Skip invalid cluster tiler shape since contiguous layout can't handle oob access + # The contiguous layout means the aligned data is stored in a contiguous manner. + # It can't handle runtime oob when alignment is not align with the tile_M, + # since the problem shape of TMA store can't be changed at runtime. + if cluster_tiler_m not in [64, 128]: + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not BlockwiseContiguousGroupedGemmKernel.is_valid_dtypes( + ab_dtype, acc_dtype, c_dtype + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not BlockwiseContiguousGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not BlockwiseContiguousGroupedGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip unsupported A/B layout + if not (a_major == "k" and b_major == "k"): + can_implement = False + return can_implement + + +def create_mask(num_groups, expect_m, fixed_m=False, m_aligned=128): + valid_m = 0 + group_m_list = [] + gidx_mapping = [] + # initialize + for i in range(num_groups): + if fixed_m: + # fixed_m for perf testing + group_m = 128 + else: + group_m = m_aligned * random.randint( + int(expect_m * 0.7) // m_aligned, int(expect_m * 1.3) // m_aligned + ) + valid_m += group_m + # handle the case that valid_m == 0 + if (i == num_groups - 1) and (valid_m == 0): + group_m = m_aligned + valid_m += group_m + group_m_list.append(group_m) + gidx_mapping.extend([i] * group_m) + + gidx_mapping = torch.tensor(gidx_mapping, device="cuda", dtype=torch.int32) + return valid_m, group_m_list, gidx_mapping + + +def create_tensors( + l, + m, + n, + k, + a_major, + b_major, + cd_major, + ab_dtype, + c_dtype, + scale_dtype, + fixed_m=False, +): + torch.manual_seed(1111) + + valid_m, group_m_list, _gidx_mapping = create_mask(l, m, fixed_m) + + a_torch_cpu = cutlass_torch.matrix(1, valid_m, k, a_major == "m", ab_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype) + c_torch_cpu = cutlass_torch.matrix(1, valid_m, n, cd_major == "m", c_dtype) + sfa_torch_cpu = cutlass_torch.matrix( + 1, valid_m, math.ceil(k / 128), True, scale_dtype + ) + sfb_torch_cpu = cutlass_torch.matrix( + l, math.ceil(n / 128), math.ceil(k / 128), False, scale_dtype + ) + + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + sfa_tensor, _ = cutlass_torch.cute_tensor_like( + sfa_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16 + ) + sfb_tensor, _ = cutlass_torch.cute_tensor_like( + sfb_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + gidx_mapping = from_dlpack(_gidx_mapping).mark_layout_dynamic() + + return ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + group_m_list, + valid_m, + ) + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + scale_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + fixed_m: bool = False, + **kwargs, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + print("Running Blackwell Persistent Dense Contiguous Grouped GEMM test with:") + print(f"mnkl: {mnkl}") + print( + f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}, Scale dtype: {scale_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Skip unsupported testcase + if not BlockwiseContiguousGroupedGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + group_m_list, + valid_m, + ) = create_tensors( + l, + m, + n, + k, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + scale_dtype, + fixed_m=fixed_m, + ) + # Configure gemm kernel + gemm = BlockwiseContiguousGroupedGemmKernel( + acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + max_active_clusters, + current_stream, + ) + + # Execution + compiled_gemm( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + current_stream, + ) + + torch.cuda.synchronize() + + # Compute reference result + if not skip_ref_check: + # update + def pad_and_multiply(scale, tensor): + cm, ck, _ = scale.shape + m, k, _ = tensor.shape + IsGroupWise = False + IsBlockWise = False + if ck == math.ceil(k / 128): + IsGroupWise = True + if cm == math.ceil(m / 128): + IsBlockWise = True + if not IsBlockWise and not IsGroupWise: + raise ValueError("Only support granularity = 128") + + k_idx = torch.arange(k, device=scale.device) + if IsGroupWise: + k_idx = k_idx // 128 + m_idx = torch.arange(m, device=scale.device) + if IsBlockWise: + m_idx = m_idx // 128 + expanded_scale = scale[m_idx[:, None], k_idx, :] + + result = expanded_scale * tensor + + return result + + updated_a = pad_and_multiply(sfa_torch_cpu, a_torch_cpu) + updated_b = pad_and_multiply(sfb_torch_cpu, b_torch_cpu) + + ref = torch.empty((1, valid_m, n), dtype=torch.float32) + start = 0 + for i, group_m in enumerate(group_m_list): + end = start + group_m + ref[0, start:end, :] = torch.einsum( + "mk,nk->mn", updated_a[start:end, :, 0], updated_b[:, :, i] + ) + start = end + + ref = ref.permute((1, 2, 0)).to(cutlass_torch.dtype(c_dtype)) + res = c_torch_gpu.view(cutlass_torch.dtype(c_dtype)) + + torch.testing.assert_close(res.cpu(), ref.cpu(), atol=tolerance, rtol=1e-03) + + def generate_tensors(): + # Reuse existing CPU reference tensors and create new GPU tensors from them + ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + group_m_list, + valid_m, + ) = create_tensors( + l, + m, + n, + k, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + scale_dtype, + fixed_m=fixed_m, + ) + return testing.JitArguments( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + + sfa_torch_cpu.numel() * sfa_torch_cpu.element_size() + + sfb_torch_cpu.numel() * sfb_torch_cpu.element_size() + + valid_m * 4 # gidx_mapping length (rows) * sizeof(int32) + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Dense Persistent GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--scale_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k"], type=str, default="k") + parser.add_argument("--b_major", choices=["k"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument("--fixed_m", action="store_true", default=False, help="Fixed M") + parser.add_argument( + "--use_cold_l2", action="store_true", default=False, help="Use cold L2" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.scale_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + args.fixed_m, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py b/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py new file mode 100644 index 00000000..a08285d9 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py @@ -0,0 +1,3033 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Type, Tuple, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +import math +import random + + +""" +High-performance persistent blockwise masked grouped dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture +using CUTE DSL. +- Matrix A is MxKxL, L is group dimension, A can be row-major("K") +- Matrix B is NxKxL, L is group dimension, B can be column-major("K") +- Matrix C is MxNxL, L is group dimension, C can be row-major("N") or column-major("M") +- Each block will apply the scale factor SFA +- Each row will apply the scale factor SFB +- For each iteration, the kernel will compute C = A * B and then apply the scale factor C *= SFA * SFB + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +Matrix A/C Memory Layout Diagrams: + + ``` + Group 0 Group 1 Group 2 + -+---------+---------+---------+ + | xx| x| xxxx| + K| xx| x| xxxx| + | xx| x| xxxx| + -+---------+---------+---------+ + |<- M ->|<- M ->|<- M ->| + ``` + x = masked elements + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final + - Type convert Final matrix to output type. + - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +.. code-block:: bash + + python examples/blackwell/blockwise_gemm/masked_grouped_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 256,4096,4096,16 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/blockwise_gemm/masked_grouped_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 256,4096,4096,16 + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp8 (e4m3fn) + see detailed valid dtype combinations in below BlockwiseMaskedGroupedGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128/256 +* Mma tiler N must be 128, align with the scaleB requirement +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class BlockwiseMaskedGroupedGemmKernel: + """This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, e5m2) + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float8E4M3FN + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float16/BFloat16 + - Other data types are not supported for accuracy issues + + :note: Constraints: + - MMA tiler M must be 64/128/256 + - MMA tiler N must be 128 + - Cluster shape M must be multiple of 2 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = BlockwiseMaskedGroupedGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell blockwise dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.acc_update_warp_id = (0, 1, 2, 3) + self.epilog_warp_id = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.scale_warp_id = 10 + self.sched_warp_id = 11 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + ) + ) + self.num_regs_uniform_warps = 64 + self.num_regs_sched_warps = 64 + self.num_regs_epilogue_warps = 216 + self.num_regs_acc_update_warps = 216 + + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 + * len((self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + # TMEM offset for final accumulator + self.tmem_final_offset = 384 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.scale_granularity_m = 1 + self.scale_granularity_n = 128 + self.scale_granularity_k = 128 + self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m + self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n + self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k + + if self.scale_k_per_tile != 1: + raise ValueError("scale_k_per_tile must be 1") + if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]: + raise ValueError("scale_m_per_tile must be cta_tile_m") + if self.scale_n_per_tile != 1: + raise ValueError("scale_n_per_tile must be 1") + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_scale_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sfa_dtype, + self.sfb_dtype, + self.scale_m_per_tile * self.scale_k_per_tile, + self.scale_n_per_tile * self.scale_k_per_tile, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.sfa_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_m_per_tile, + ), + ) + self.sfb_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_n, self.scale_n_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_n_per_tile, + ), + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + gidx_mapping: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param sfa: Scale factor tensor A + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param gidx_mapping: Mapping from group index to m + :type gidx_mapping: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type + self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + tensor_sfa = cute.make_tensor( + sfa.iterator, + cute.make_layout( + ( + (self.scale_granularity_m, sfa.shape[0]), + (self.scale_granularity_k, sfa.shape[1]), + sfa.shape[2], + ), + stride=( + (0, sfa.layout.stride[0]), + (0, sfa.layout.stride[1]), + sfa.layout.stride[2], + ), + ), + ) + tensor_sfb = cute.make_tensor( + sfb.iterator, + cute.make_layout( + ( + (self.scale_granularity_n, sfb.shape[0]), + (self.scale_granularity_k, sfb.shape[1]), + sfb.shape[2], + ), + stride=( + (0, sfb.layout.stride[0]), + (0, sfb.layout.stride[1]), + sfb.layout.stride[2], + ), + ), + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + scale_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_scale_stage * 2 + ] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_tile_stage * 2 + ] + epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tensor_sfa, + tensor_sfb, + gidx_mapping, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + gidx_mapping: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize mainloop scale_pipeline (barrier) and states + scale_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + scale_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + scale_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.scale_mbar_ptr.data_ptr(), + num_stages=self.num_scale_stage, + producer_group=scale_pipeline_producer_group, + consumer_group=scale_pipeline_consumer_group, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize epilogue pipeline (barrier) and states + epi_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.acc_update_warp_id), + ) + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + epi_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.epi_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + ) + + # Initialize tile info pipeline (barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # coordinate + cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl)) + cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl)) + # (bM, bK, loopM, loopK, loopL) + cSFA = cute.local_tile( + cSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + cSFB = cute.local_tile( + cSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # scale viewed as C tensor + sSFA_view_as_C_layout = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + self.cta_tile_shape_mnk[1], + self.num_scale_stage, + ), + stride=((0, 1), 0, self.scale_m_per_tile), + ) + sSFB_view_as_C_layout = cute.make_layout( + ( + self.cta_tile_shape_mnk[0], + (self.scale_granularity_n, self.scale_n_per_tile), + self.num_scale_stage, + ), + stride=(0, (0, 1), self.scale_n_per_tile), + ) + sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout) + sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition global/shared tensor for TMA load A/B + # + # load scaleA/scaleB + atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=mSFA_mkl.element_type.width, + ) + tiled_copy_sfa = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + tiled_copy_sfb = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx) + thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl) + tAsSFA = thr_copy_sfa.partition_D(sSFA) + tAcSFA = thr_copy_sfa.partition_S(cSFA) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl) + tBsSFB = thr_copy_sfb.partition_D(sSFB) + tBcSFB = thr_copy_sfb.partition_S(cSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + cur_m = cutlass.Int32(0) + cur_boundary = cutlass.Int32(0) + is_valid_m = cutlass.Boolean(False) + + while (not is_valid_m) and work_tile.is_valid_tile: + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # fetch the tile info + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_m = ( + cur_tile_coord[0] // self.cluster_shape_mn[0] + ) * self.cluster_shape_mn[0] + cur_m = mma_tile_coord_m * self.cta_tile_shape_mnk[0] + cur_boundary = gidx_mapping[cur_tile_coord[2]] + is_valid_m = cur_m < cur_boundary + + cur_tile_coord = work_tile.tile_idx + + # acquire tile info pipeline + tile_info_pipeline.producer_acquire(tile_info_producer_state) + + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = cur_tile_coord[2] + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + # commit tile info pipeline + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized Scale load warp + # + if warp_idx == self.scale_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + scale_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # + # Prepare the mask for scaleA/scaleB + # + tApSFA = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tAsSFA, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + tBpSFB = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tBsSFB, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + + # Peek (try_wait) SCALE buffer empty + scale_producer_state.reset_count() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # + # Slice to per mma tile index + # + tAsSFA_pipe = cute.filter_zeros( + tAsSFA[(None, None, None, scale_producer_state.index)] + ) + tBsSFB_pipe = cute.filter_zeros( + tBsSFB[(None, None, None, scale_producer_state.index)] + ) + tAgSFA_k = cute.filter_zeros( + tAgSFA_mkl[ + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + tBgSFB_k = cute.filter_zeros( + tBgSFB_nkl[ + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + + tAcSFA_compact = cute.filter_zeros( + cute.slice_( + tAcSFA, + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + tBcSFB_compact = cute.filter_zeros( + cute.slice_( + tBcSFB, + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])): + tApSFA[((0, 0), i, (0, 0))] = cute.elem_less( + tAcSFA_compact[(i)][0], mSFA_mkl.shape[0] + ) + for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])): + tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less( + tBcSFB_compact[(i)][0], mSFB_nkl.shape[0] + ) + + # Conditionally wait for Scale buffer empty + scale_pipeline.producer_acquire( + scale_producer_state, peek_scale_empty_status + ) + + # load scaleA/scaleB + cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA) + cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB) + + scale_pipeline.producer_commit(scale_producer_state) + + # Peek (try_wait) Scale buffer empty + scale_producer_state.advance() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait Scale buffer empty + # + scale_pipeline.producer_tail(scale_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Mma mainloop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire( + acc_producer_state, peek_acc_empty_status + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized acc update warps + # + if warp_idx <= self.acc_update_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps) + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc_base, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc, + tRT_tAcc_base, + ) = self.acc_update_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCtAcc_final, + tCgC, + sSFA_view_as_C, + sSFB_view_as_C, + epi_tile, + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + scale_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_scale_stage + ) + + epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # initialize the final accumulator + tTR_rAcc_final.fill(0.0) + + tTR_rSFA = cute.make_rmem_tensor( + cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + tTR_rSFB = cute.make_rmem_tensor( + cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + + scale_consumer_state.reset_count() + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + + acc_consumer_state.reset_count() + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for scale buffer full + # + scale_pipeline.consumer_wait( + scale_consumer_state, peek_scale_full_status + ) + + tTR_sSFA_slice = cute.slice_( + tTR_sSFA, + (None, None, None, 0, None, scale_consumer_state.index), + ) + tTR_sSFB_slice = cute.slice_( + tTR_sSFB, + (None, None, None, 0, None, scale_consumer_state.index), + ) + + scale_atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA) + cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB) + + # + # Wait for accumulator buffer full + # + + acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Update accumulator by scale factor in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Update accumulator by scale factor + # + tTR_rAcc_subtile = tTR_rAcc_final[ + (None, None, None, subtile_idx) + ] + tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)] + tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)] + + acc_vec = tTR_rAcc.load() + final_vec = tTR_rAcc_subtile.load() + scale_a = tTR_rSFA_subtile.load() + scale_b = tTR_rSFB_subtile.load() + scale = scale_a * scale_b + final_vec = acc_vec * scale + final_vec + tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) + + # + # Async arrive accumulator buffer empty + # + scale_pipeline.consumer_release(scale_consumer_state) + scale_consumer_state.advance() + + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)] + tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc)) + + # + # Wait for epilogue buffer empty + # + epi_pipeline.producer_acquire(epi_producer_state) + + # copy the accumulator to tensor memory buffer + cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc) + cute.arch.fence_view_async_tmem_store() + + # + # Async arrive epilogue buffer full + # + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]: + cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps) + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + epi_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, 1 + ) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + num_prev_subtiles = cutlass.Int32(0) + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + mma_tile_coord_mnl[2], + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, epi_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + epi_pipeline.consumer_wait(epi_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + c_buffer = num_prev_subtiles % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + epi_pipeline.consumer_release(epi_consumer_state) + epi_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def acc_update_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + tAcc_final: cute.Tensor, + gC_mnl: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epi_tile: cute.Tile, + ) -> Tuple[ + cute.TiledCopy, + cute.TiledCopy, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + ]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination). + Partition the scale factor tensor for related copy operations. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param tAcc_final: The final accumulator tensor to be copied and partitioned + :type tAcc_final: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param sSFA: The scale factor tensor for A + :type sSFA: cute.Tensor + :param sSFB: The scale factor tensor for B + :type sSFB: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + - tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results + - tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r + - tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r + - tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results + - tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t + :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + tmem_load_atom = None + tmem_store_atom = None + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tAcc_final_epi = cute.flat_divide( + tAcc_final[((None, None), 0, 0, None)], epi_tile + ) + + tiled_copy_t2r = tcgen05.make_tmem_copy( + tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)] + ) + tiled_copy_r2t = tcgen05.make_tmem_copy( + tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sSFA_epi = cute.flat_divide(sSFA, epi_tile) + sSFB_epi = cute.flat_divide(sSFB, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi) + tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc_final_ = cute.make_rmem_tensor( + tTR_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + tTR_rAcc_final = cute.group_modes( + tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_) + ) + + tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi) + tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi) + # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL) + tRT_rAcc_final_ = cute.make_rmem_tensor( + tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N)) + tRT_rAcc_final = cute.group_modes( + tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_) + ) + + return ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc_final, + tRT_tAcc_final, + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sfa_dtype: Type[cutlass.Numeric], + sfb_dtype: Type[cutlass.Numeric], + sfa_count: int, + sfb_count: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout of operand C. + :type c_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6 + + # Default C stages + num_c_stage = 2 + + # Default ScaleA/B stages + num_scale_stage = 10 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + # 1024B alignment + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage + sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage + scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024 + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity // occupancy + - (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[ + cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp + ]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {cutlass.Float32}: + is_valid = False + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + if mma_tiler_mn[1] not in (128,): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not BlockwiseMaskedGroupedGemmKernel.is_valid_dtypes( + ab_dtype, acc_dtype, c_dtype + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not BlockwiseMaskedGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not BlockwiseMaskedGroupedGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip unsupported A/B layout + if not (a_major == "k" and b_major == "k"): + can_implement = False + return can_implement + + +def create_mask(num_groups: int, m: int, fixed_m=False, tile_m=128): + # align with block_m (or block_n if swapAB) + masked_m_candidates = list( + filter( + lambda candidate: candidate <= m, + (1, 16, tile_m, tile_m * 2, tile_m * 3, tile_m * 4), + ) + ) + masked_m = torch.empty((num_groups,), dtype=torch.int32) + for j in range(num_groups): + if fixed_m: + masked_m[j] = 1 + else: + masked_m[j] = random.choice(masked_m_candidates) + gidx_mapping = masked_m.to(device="cuda", dtype=torch.int32) + + return gidx_mapping, masked_m + + +def create_tensors( + l, + m, + n, + k, + a_major, + b_major, + cd_major, + ab_dtype, + c_dtype, + scale_dtype, + fixed_m=False, +): + torch.manual_seed(1111) + + _gidx_mapping, masked_m = create_mask(l, m, fixed_m) + + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, cd_major == "m", c_dtype) + sfa_torch_cpu = cutlass_torch.matrix(l, m, math.ceil(k / 128), True, scale_dtype) + sfb_torch_cpu = cutlass_torch.matrix( + l, math.ceil(n / 128), math.ceil(k / 128), False, scale_dtype + ) + + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + sfa_tensor, _ = cutlass_torch.cute_tensor_like( + sfa_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16 + ) + sfb_tensor, _ = cutlass_torch.cute_tensor_like( + sfb_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + gidx_mapping = from_dlpack(_gidx_mapping).mark_layout_dynamic() + + return ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + masked_m, + ) + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + scale_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + fixed_m: bool = False, + **kwargs, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + print("Running Blackwell Persistent Dense Blockwise GEMM test with:") + print(f"mnkl: {mnkl}") + print( + f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}, Scale dtype: {scale_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + if not BlockwiseMaskedGroupedGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + masked_m, + ) = create_tensors( + l, + m, + n, + k, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + scale_dtype, + fixed_m=fixed_m, + ) + # Configure gemm kernel + gemm = BlockwiseMaskedGroupedGemmKernel( + acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + max_active_clusters, + current_stream, + ) + + # Execution + compiled_gemm( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + current_stream, + ) + + torch.cuda.synchronize() + + # Compute reference result + if not skip_ref_check: + # update + def pad_and_multiply(scale, tensor): + cm, ck, _ = scale.shape + m, k, _ = tensor.shape + IsGroupWise = False + IsBlockWise = False + if ck == math.ceil(k / 128): + IsGroupWise = True + if cm == math.ceil(m / 128): + IsBlockWise = True + if not IsBlockWise and not IsGroupWise: + raise ValueError("Only support granularity = 128") + + k_idx = torch.arange(k, device=scale.device) + if IsGroupWise: + k_idx = k_idx // 128 + m_idx = torch.arange(m, device=scale.device) + if IsBlockWise: + m_idx = m_idx // 128 + expanded_scale = scale[m_idx[:, None], k_idx, :] + + result = expanded_scale * tensor + + return result + + updated_a = pad_and_multiply(sfa_torch_cpu, a_torch_cpu) + updated_b = pad_and_multiply(sfb_torch_cpu, b_torch_cpu) + + ref = ( + torch.einsum("mkl,nkl->mnl", updated_a, updated_b) + .to(cutlass_torch.dtype(c_dtype)) + .cpu() + ) + res = c_torch_gpu.view(cutlass_torch.dtype(c_dtype)).cpu() + + for j in range(l): + res_j = res[: masked_m[j].item(), :, j] + ref_j = ref[: masked_m[j].item(), :, j] + torch.testing.assert_close(res_j, ref_j, atol=tolerance, rtol=1e-03) + + def generate_tensors(): + # Reuse existing CPU reference tensors and create new GPU tensors from them + ( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + c_torch_gpu, + masked_m, + ) = create_tensors( + l, + m, + n, + k, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + scale_dtype, + fixed_m=fixed_m, + ) + return testing.JitArguments( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + gidx_mapping, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + + sfa_torch_cpu.numel() * sfa_torch_cpu.element_size() + + sfb_torch_cpu.numel() * sfb_torch_cpu.element_size() + + l * 4 # group * sizeof(int32) + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Dense Persistent GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--scale_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k"], type=str, default="k") + parser.add_argument("--b_major", choices=["k"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument("--fixed_m", action="store_true", default=False, help="Fixed M") + parser.add_argument( + "--use_cold_l2", action="store_true", default=False, help="Use cold L2" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.scale_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + args.fixed_m, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py index c9c24b84..84367506 100644 --- a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse -from typing import Optional, Type, Tuple, Union +from typing import Type, Tuple, Union import cuda.bindings.driver as cuda import torch @@ -85,7 +85,7 @@ Input arguments to this example is shown below: .. code-block:: bash - python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + python examples/blackwell/dense_blockscaled_gemm_persistent.py \ --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ --c_dtype Float16 \ --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ @@ -95,7 +95,7 @@ To collect performance with NCU profiler: .. code-block:: bash - ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \ --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ --c_dtype Float16 \ --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ @@ -108,7 +108,7 @@ Constraints: see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation * A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) * Mma tiler M must be 128 or 256(use_2cta_instrs) -* Mma tiler N must be 128 or 256 +* Mma tiler N must be 64/128/192/256 * Cluster shape M/N must be positive and power of 2, total cluster size <= 16 * Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, @@ -144,7 +144,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: - Float8E4M3FN/Float8E5M2 :note: Constraints: - MMA tiler M must be 128 or 256 (use_2cta_instrs) - - MMA tiler N must be 128/256 + - MMA tiler N must be 64/128/192/256 - Cluster shape M must be multiple of 2 if Mma tiler M is 256 - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors @@ -209,9 +209,18 @@ class Sm100BlockScaledPersistentDenseGemmKernel: (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) ) # Set barrier id for cta sync, epilogue sync and tmem ptr sync - self.cta_sync_bar_id = 0 - self.epilog_sync_bar_id = 1 - self.tmem_ptr_sync_bar_id = 2 + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") SM100_TMEM_CAPACITY_COLUMNS = 512 self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS @@ -228,21 +237,17 @@ class Sm100BlockScaledPersistentDenseGemmKernel: - Computing epilogue subtile - Setting up A/B/SFA/SFB/C stage counts in shared memory - Computing A/B/SFA/SFB/C shared memory layout - - Computing tensor memory allocation columns """ # Compute mma instruction shapes - mma_inst_bits_k = 256 # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) - self.mma_inst_shape_mnk = ( + self.mma_inst_shape_mn = ( self.mma_tiler[0], self.mma_tiler[1], - mma_inst_bits_k // self.a_dtype.width, ) # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) - self.mma_inst_shape_mnk_sfb = ( - self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1), - cute.round_up(self.mma_inst_shape_mnk[1], 128), - self.mma_inst_shape_mnk[2], + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), ) tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( @@ -252,7 +257,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: self.sf_dtype, self.sf_vec_size, self.cta_group, - self.mma_inst_shape_mnk[:2], + self.mma_inst_shape_mn, ) tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( @@ -262,20 +267,21 @@ class Sm100BlockScaledPersistentDenseGemmKernel: self.sf_dtype, self.sf_vec_size, cute.nvgpu.tcgen05.CtaGroup.ONE, - self.mma_inst_shape_mnk_sfb[:2], + self.mma_inst_shape_mn_sfb, ) # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) mma_inst_tile_k = 4 self.mma_tiler = ( - self.mma_inst_shape_mnk[0], - self.mma_inst_shape_mnk[1], - self.mma_inst_shape_mnk[2] * mma_inst_tile_k, + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, ) self.mma_tiler_sfb = ( - self.mma_inst_shape_mnk_sfb[0], - self.mma_inst_shape_mnk_sfb[1], - self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k, + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, ) self.cta_tile_shape_mnk = ( self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), @@ -314,9 +320,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: tiled_mma, self.mma_tiler, self.a_dtype, - self.a_major_mode, self.b_dtype, - self.b_major_mode, self.epi_tile, self.c_dtype, self.c_layout, @@ -431,7 +435,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: self.sf_dtype, self.sf_vec_size, self.cta_group, - self.mma_inst_shape_mnk[:2], + self.mma_inst_shape_mn, ) tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( @@ -441,7 +445,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: self.sf_dtype, self.sf_vec_size, cute.nvgpu.tcgen05.CtaGroup.ONE, - self.mma_inst_shape_mnk_sfb[:2], + self.mma_inst_shape_mn_sfb, ) atom_thr_size = cute.size(tiled_mma.thr_id.shape) @@ -507,6 +511,31 @@ class Sm100BlockScaledPersistentDenseGemmKernel: internal_type=cutlass.Int16, ) + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + + new_shape = ( + ( + tma_tensor_sfb.shape[0][0], + ((2, 2), y) + ), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2] + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + ( + tma_tensor_sfb.stride[0][0], + ((x, x), x_times_3) + ), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2] + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) @@ -628,7 +657,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: mSFA_mkl: cute.Tensor, tma_atom_sfb: cute.CopyAtom, mSFB_nkl: cute.Tensor, - tma_atom_c: Optional[cute.CopyAtom], + tma_atom_c: cute.CopyAtom, mC_mnl: cute.Tensor, cluster_layout_vmnk: cute.Layout, cluster_layout_sfb_vmnk: cute.Layout, @@ -636,7 +665,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: b_smem_layout_staged: cute.ComposedLayout, sfa_smem_layout_staged: cute.Layout, sfb_smem_layout_staged: cute.Layout, - c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], epi_tile: cute.Tile, tile_sched_params: utils.PersistentTileSchedulerParams, epilogue_op: cutlass.Constexpr, @@ -684,9 +713,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel: smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf - # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 @@ -719,14 +745,13 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ) # Tensor memory dealloc barrier init - if use_2cta_instrs: - if warp_idx == self.tma_warp_id: - num_tmem_dealloc_threads = 32 - with cute.arch.elect_one(): - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads - ) - cute.arch.mbarrier_init_fence() + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) # Cluster arrive after barrier init if cute.size(self.cluster_shape_mn) > 1: @@ -790,7 +815,9 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ) # (bN, bK, RestN, RestK, RestL) gSFB_nkl = cute.local_tile( - mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + mSFB_nkl, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), ) # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( @@ -894,9 +921,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() else: - cute.arch.barrier( - barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta - ) + self.cta_sync_barrier.arrive_and_wait() # # Specialized TMA load warp @@ -915,7 +940,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ) while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( @@ -940,9 +964,13 @@ class Sm100BlockScaledPersistentDenseGemmKernel: tAgSFA_slice = tAgSFA[ (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) ] + + slice_n = mma_tile_coord_mnl[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_mnl[1] // 2 # ((atom_v, rest_v), RestK) tBgSFB_slice = tBgSFB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + (None, slice_n, None, mma_tile_coord_mnl[2]) ] # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt @@ -1017,21 +1045,13 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # # Bar sync for retrieve tensor memory ptr from shared mem # - tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) - cute.arch.barrier( - barrier_id=self.tmem_ptr_sync_bar_id, - number_of_threads=tmem_ptr_read_threads, - ) + tmem.wait_for_alloc() # # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # Make accumulator tmem tensor - acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, - alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, - ) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) @@ -1067,12 +1087,16 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # # Partition for S2T copy of SFA/SFB # - tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( - self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) - ) - tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( - self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) - ) + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) # # Persistent tile scheduling loop @@ -1116,6 +1140,30 @@ class Sm100BlockScaledPersistentDenseGemmKernel: if is_leader_cta: acc_pipeline.producer_acquire(acc_producer_state) + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB) + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + # Move in increments of 64 columns of SFB + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + # # Reset the ACCUMULATE field for each tile # @@ -1170,7 +1218,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ) tiled_mma.set( tcgen05.Field.SFB, - tCtSFB[sf_kblock_coord].iterator, + tCtSFB_mma[sf_kblock_coord].iterator, ) cute.gemm( @@ -1220,30 +1268,17 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # # Alloc tensor memory buffer # - if warp_idx == self.epilog_warp_id[0]: - cute.arch.alloc_tmem( - self.num_tmem_alloc_cols, - tmem_holding_buf, - is_two_cta=use_2cta_instrs, - ) + tmem.allocate(self.num_tmem_alloc_cols) # # Bar sync for retrieve tensor memory ptr from shared memory # - tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) - cute.arch.barrier( - barrier_id=self.tmem_ptr_sync_bar_id, - number_of_threads=tmem_ptr_read_threads, - ) + tmem.wait_for_alloc() # # Retrieving tensor memory ptr and make accumulator tensor # - acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, - alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, - ) + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) @@ -1251,20 +1286,24 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # Partition for epilogue # epi_tidx = tidx - tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( - self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs - ) + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs ) - tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( tiled_copy_t2r, tTR_rC, epi_tidx, sC ) - tma_atom_c, bSG_sC, bSG_gC_partitioned = ( - self.epilog_gmem_copy_and_partition( - epi_tidx, tma_atom_c, tCgC, epi_tile, sC - ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC ) # @@ -1283,7 +1322,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel: c_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), - 32 * len(self.epilog_warp_id), ) c_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.num_c_stage, @@ -1291,7 +1329,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ) while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( @@ -1360,11 +1397,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, ) - epilog_threads = 32 * len(self.epilog_warp_id) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=epilog_threads, - ) + self.epilog_sync_barrier.arrive_and_wait() # # TMA store C to global memory @@ -1378,10 +1411,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # Fence and barrier to make sure shared memory store is visible to TMA store c_pipeline.producer_commit() c_pipeline.producer_acquire() - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=epilog_threads, - ) + self.epilog_sync_barrier.arrive_and_wait() # # Async arrive accumulator buffer empty @@ -1399,21 +1429,9 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # # Dealloc the tensor memory buffer # - if warp_idx == self.epilog_warp_id[0]: - cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) - epilog_threads = 32 * len(self.epilog_warp_id) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads - ) - if warp_idx == self.epilog_warp_id[0]: - if use_2cta_instrs: - cute.arch.mbarrier_arrive( - tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 - ) - cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - cute.arch.dealloc_tmem( - acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs - ) + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(acc_tmem_ptr) # # Wait for C store complete # @@ -1520,7 +1538,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_fragment( + tTR_rAcc = cute.make_rmem_tensor( tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc @@ -1614,9 +1632,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: tiled_mma: cute.TiledMma, mma_tiler_mnk: Tuple[int, int, int], a_dtype: Type[cutlass.Numeric], - a_major_mode: tcgen05.OperandMajorMode, b_dtype: Type[cutlass.Numeric], - b_major_mode: tcgen05.OperandMajorMode, epi_tile: cute.Tile, c_dtype: Type[cutlass.Numeric], c_layout: utils.LayoutEnum, @@ -1633,12 +1649,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel: :type mma_tiler_mnk: tuple[int, int, int] :param a_dtype: Data type of operand A. :type a_dtype: type[cutlass.Numeric] - :param a_major_mode: Major mode of operand A. - :type a_major_mode: tcgen05.OperandMajorMode :param b_dtype: Data type of operand B. :type b_dtype: type[cutlass.Numeric] - :param b_major_mode: Major mode of operand B. - :type b_major_mode: tcgen05.OperandMajorMode :param epi_tile: The epilogue tile shape. :type epi_tile: cute.Tile :param c_dtype: Data type of operand C (output). @@ -1830,7 +1842,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: c_major: str, ) -> bool: """ - Check if the dtypes and sf_vec_size are valid combinations + Check if layouts and dtypes are valid combinations :param ab_dtype: The data type of the A and B operands :type ab_dtype: Type[cutlass.Numeric] @@ -1870,9 +1882,9 @@ class Sm100BlockScaledPersistentDenseGemmKernel: """ is_valid = True # Skip invalid mma tile shape - if not mma_tiler_mn[0] in [128, 256]: + if mma_tiler_mn[0] not in [128, 256]: is_valid = False - if not mma_tiler_mn[1] in [128, 256]: + if mma_tiler_mn[1] not in [64, 128, 192, 256]: is_valid = False # Skip illegal cluster shape if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: @@ -2088,7 +2100,7 @@ def run( :return: Execution time of the GEMM kernel :rtype: float """ - print(f"Running Sm100 Persistent Dense BlockScaled GEMM test with:") + print("Running Sm100 Persistent Dense BlockScaled GEMM test with:") print(f"mnkl: {mnkl}") print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") print(f"C dtype: {c_dtype}") @@ -2143,21 +2155,21 @@ def run( c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 ) - # Mark tensor to be byte aligned + # Mark tensor with element divisibility for 16B alignment a_tensor.mark_compact_shape_dynamic( mode=1 if a_major == "k" else 0, stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), - divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, ) b_tensor.mark_compact_shape_dynamic( mode=1 if b_major == "k" else 0, stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), - divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, ) c_tensor.mark_compact_shape_dynamic( mode=1 if c_major == "n" else 0, stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), - divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1, + divisibility=32 if ab_dtype == cutlass.Float4E2M1FN else 16, ) # Create scale factor tensor SFA/SFB @@ -2374,6 +2386,7 @@ def run( return exec_time # Return execution time in microseconds + if __name__ == "__main__": def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm.py b/examples/python/CuTeDSL/blackwell/dense_gemm.py index f5a83729..304c86bb 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -39,7 +39,8 @@ import cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass.cute.runtime import from_dlpack + +import cutlass.cute.testing as testing """ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture @@ -203,6 +204,7 @@ class DenseGemmKernel: self.use_2cta_instrs = use_2cta_instrs self.cluster_shape_mn = cluster_shape_mn # K dimension is deferred in _setup_attributes + self.mma_tiler_mn = mma_tiler_mn self.mma_tiler = (*mma_tiler_mn, 1) self.use_tma_store = use_tma_store @@ -414,59 +416,17 @@ class DenseGemmKernel: tma_atom_c = None tma_tensor_c = None if cutlass.const_expr(self.use_tma_store): - c_cta_v_layout = cute.composition( - cute.make_identity_layout(c.shape), self.epi_tile - ) epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, - c_cta_v_layout, + self.epi_tile, ) # Compute grid size grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) - self.buffer_align_bytes = 1024 - - c_smem_size = ( - cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 - ) - - # Define shared storage for kernel - @cute.struct - class SharedStorage: - ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 - tmem_holding_buf: cutlass.Int32 - # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC: cute.struct.Align[ - cute.struct.MemRange[ - self.c_dtype, - c_smem_size, - ], - self.buffer_align_bytes, - ] - # (MMA, MMA_M, MMA_K, STAGE) - sA: cute.struct.Align[ - cute.struct.MemRange[ - self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) - ], - self.buffer_align_bytes, - ] - # (MMA, MMA_N, MMA_K, STAGE) - sB: cute.struct.Align[ - cute.struct.MemRange[ - self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) - ], - self.buffer_align_bytes, - ] - - self.shared_storage = SharedStorage - # Launch the kernel synchronously self.kernel( tiled_mma, @@ -551,11 +511,17 @@ class DenseGemmKernel: # # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier # - smem = utils.SmemAllocator() - storage = smem.allocate(self.shared_storage) + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_stage * 2 + ] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -563,25 +529,19 @@ class DenseGemmKernel: ab_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_tma_producer ) - ab_pipeline = pipeline.PipelineTmaUmma.create( + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, - ) - ab_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_ab_stage - ) - ab_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_ab_stage - ) + ).make_participants() # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) acc_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + pipeline.Agent.Thread, self.threads_per_cta ) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), @@ -597,15 +557,16 @@ class DenseGemmKernel: pipeline.PipelineUserType.Consumer, self.num_acc_stage ) + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=0, num_threads=self.threads_per_cta + ) # Tensor memory dealloc barrier init - if use_2cta_instrs: - if warp_idx == 0: - num_tmem_dealloc_threads = 32 - with cute.arch.elect_one(): - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads - ) - cute.arch.mbarrier_init_fence() + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) # Cluster arrive after barrier init if cute.size(self.cluster_shape_mn) > 1: @@ -615,20 +576,28 @@ class DenseGemmKernel: # Setup smem tensor A/B/C # # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC = ( - storage.sC.get_tensor( - c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + sC = None + if cutlass.const_expr(self.use_tma_store): + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, ) - if self.use_tma_store - else None - ) + # (MMA, MMA_M, MMA_K, STAGE) - sA = storage.sA.get_tensor( - a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, ) # (MMA, MMA_N, MMA_K, STAGE) - sB = storage.sB.get_tensor( - b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, ) # @@ -720,56 +689,16 @@ class DenseGemmKernel: if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() - # # Alloc tensor memory buffer - # - if warp_idx == 0: - cute.arch.alloc_tmem( - self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs - ) + tmem.allocate(self.num_tmem_alloc_cols) - # - # Bar sync for retrieve tensor memory ptr from shared memory - # - cute.arch.barrier() + # Barrier before retrieve tensor memory ptr from shared memory + tmem.wait_for_alloc() - # - # Retrieving tensor memory ptr and make accumulator tensor - # - tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf - ) + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N) tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) - # - # Partition for epilogue - # - tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( - tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs - ) - - tTR_rC = None - tiled_copy_r2s = None - simt_atom = None - tRS_rC = None - tRS_sC = None - bSG_sC = None - bSG_gC = None - tTR_gC = None - if cutlass.const_expr(self.use_tma_store): - tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( - tiled_copy_t2r, tTR_rC, tidx, sC - ) - tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( - tidx, tma_atom_c, tCgC, epi_tile, sC - ) - else: - simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( - tidx, tiled_copy_t2r, tCgC, epi_tile, sC - ) - # # Slice to per mma tile index # @@ -777,123 +706,91 @@ class DenseGemmKernel: tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - if cutlass.const_expr(self.use_tma_store): - # ((ATOM_V, REST_V), EPI_M, EPI_N) - bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] - else: - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) - tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] # # Pipelining TMA load A/B and MMA mainloop # prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) - if warp_idx == 0: - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt - peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) # # Prefetch TMA load A/B # - for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): + for k_tile_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): # Conditionally wait for AB buffer empty - ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + producer_handle = ab_producer.acquire_and_advance() # TMA load A/B cute.copy( tma_atom_a, - tAgA[(None, ab_producer_state.count)], - tAsA[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tAgA[(None, k_tile_idx)], + tAsA[(None, producer_handle.index)], + tma_bar_ptr=producer_handle.barrier, mcast_mask=a_full_mcast_mask, ) cute.copy( tma_atom_b, - tBgB[(None, ab_producer_state.count)], - tBsB[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tBgB[(None, k_tile_idx)], + tBsB[(None, producer_handle.index)], + tma_bar_ptr=producer_handle.barrier, mcast_mask=b_full_mcast_mask, ) - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 - ab_producer_state.advance() - peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_full_status = cutlass.Boolean(False) + if is_leader_cta: + peek_ab_full_status = ab_consumer.try_wait() - # Peek (try_wait) AB buffer full for k_tile = 0 - peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_tile_cnt and is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + peek_ab_empty_status = ab_producer.try_acquire() # # MMA mainloop # - for k_tile in range(k_tile_cnt): + for k_tile_idx in cutlass.range(k_tile_cnt): # Conditionally wait for AB buffer empty - ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + if k_tile_idx < k_tile_cnt - prefetch_k_tile_cnt: + producer_handle = ab_producer.acquire_and_advance( + peek_ab_empty_status + ) - if ab_producer_state.count < k_tile_cnt: # TMA load A/B cute.copy( tma_atom_a, - tAgA[(None, ab_producer_state.count)], - tAsA[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tAgA[(None, producer_handle.count)], + tAsA[(None, producer_handle.index)], + tma_bar_ptr=producer_handle.barrier, mcast_mask=a_full_mcast_mask, ) cute.copy( tma_atom_b, - tBgB[(None, ab_producer_state.count)], - tBsB[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tBgB[(None, producer_handle.count)], + tBsB[(None, producer_handle.index)], + tma_bar_ptr=producer_handle.barrier, mcast_mask=b_full_mcast_mask, ) if is_leader_cta: # Conditionally wait for AB buffer full - ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + consumer_handle = ab_consumer.wait_and_advance(peek_ab_full_status) # tCtAcc += tCrA * tCrB - num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + num_kblks = cute.size(tCrA, mode=[2]) + for kblk_idx in cutlass.range(num_kblks, unroll_full=True): + kblk_crd = (None, None, kblk_idx, consumer_handle.index) cute.gemm( - tiled_mma, - tCtAcc, - tCrA[kblock_coord], - tCrB[kblock_coord], - tCtAcc, + tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc ) # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty - ab_pipeline.consumer_release(ab_consumer_state) + consumer_handle.release() # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 - ab_producer_state.advance() - peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_empty_status = ab_producer.try_acquire() # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 - ab_consumer_state.advance() - peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_tile_cnt: - if is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + peek_ab_full_status = ab_consumer.try_wait() # Async arrive accumulator buffer full if is_leader_cta: @@ -904,114 +801,39 @@ class DenseGemmKernel: # # Release tensor memory allocation lock - if warp_idx == 0: - cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + tmem.relinquish_alloc_permit() # Wait for accumulator buffer full acc_pipeline.consumer_wait(acc_consumer_state) - tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) if cutlass.const_expr(self.use_tma_store): - bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + assert tma_atom_c is not None and sC is not None + self.epilogue_tma_store( + tidx, + warp_idx, + mma_tile_coord_mnl, # type: ignore + tma_atom_c, + tCtAcc, + sC, + tCgC, + epi_tile, + epilogue_op, + ) + else: - tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) - - c_pipeline = None - if cutlass.const_expr(self.use_tma_store): - # Initialize tma store c_pipeline - c_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta - ) - c_pipeline = pipeline.PipelineTmaStore.create( - num_stages=self.num_c_stage, - producer_group=c_producer_group, - ) - - # - # Store accumulator to global memory in subtiles - # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - for subtile_idx in range(subtile_cnt): - # - # Load accumulator from tensor memory buffer to register - # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] - cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - - if cutlass.const_expr(self.use_tma_store): - # - # Perform epilogue op on accumulator and convert to C type - # - acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() - acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) - tRS_rC.store(acc_vec) - - # - # Store C to shared memory - # - c_buffer = subtile_idx % self.num_c_stage - cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) - cute.arch.barrier() - - # - # TMA store C to global memory - # - if warp_idx == 0: - cute.copy( - tma_atom_c, - bSG_sC[(None, c_buffer)], - bSG_gC[(None, subtile_idx)], - ) - # Fence and barrier to make sure TMA store is completed to recollect C buffer - c_pipeline.producer_commit() - c_pipeline.producer_acquire() - cute.arch.barrier() - else: - # - # Perform epilogue op on accumulator and convert to C type - # - acc_vec = tTR_rAcc.load() - acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) - tTR_rC.store(acc_vec) - - # - # Store C to global memory - # - cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + self.epilogue(tidx, mma_tile_coord_mnl, tCtAcc, tCgC, epi_tile, epilogue_op) # type: ignore # # Dealloc the tensor memory buffer # - cute.arch.barrier() - if warp_idx == 0: - if use_2cta_instrs: - cute.arch.mbarrier_arrive( - tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 - ) - cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - cute.arch.dealloc_tmem( - tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs - ) - - # - # Wait for C store complete - # - if cutlass.const_expr(self.use_tma_store): - c_pipeline.producer_tail() + pipeline.sync(barrier_id=1) + tmem.free(tmem_ptr) # # Wait A/B buffer empty # if warp_idx == 0: - # Reverse prefetch_k_tile_cnt times to next available buffer - for i in range(prefetch_k_tile_cnt): - ab_producer_state.reverse() - ab_pipeline.producer_tail(ab_producer_state) + ab_producer.tail() return def epilog_tmem_copy_and_partition( @@ -1023,7 +845,8 @@ class DenseGemmKernel: use_2cta_instrs: Union[cutlass.Boolean, bool], ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: """ - Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) + and register array (destination). :param tidx: The thread index in epilogue warp groups :type tidx: cutlass.Int32 @@ -1052,10 +875,7 @@ class DenseGemmKernel: use_2cta_instrs, ) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) - tAcc_epi = cute.flat_divide( - tAcc[((None, None), 0, 0)], - epi_tile, - ) + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0)], epi_tile) # (EPI_TILE_M, EPI_TILE_N) tiled_copy_t2r = tcgen05.make_tmem_copy( copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] @@ -1072,7 +892,7 @@ class DenseGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_fragment( + tTR_rAcc = cute.make_rmem_tensor( tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc @@ -1085,7 +905,8 @@ class DenseGemmKernel: sC: cute.Tensor, ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: """ - Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + Make tiledCopy for shared memory store, then use it to partition register array (source) + and shared memory (destination). :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) :type tiled_copy_t2r: cute.TiledCopy @@ -1113,69 +934,185 @@ class DenseGemmKernel: tRS_rC = tiled_copy_r2s.retile(tTR_rC) return tiled_copy_r2s, tRS_rC, tRS_sC - def epilog_gmem_copy_and_partition( + @cute.jit + def epilogue_tma_store( self, - tidx: cutlass.Int32, - atom: Union[cute.CopyAtom, cute.TiledCopy], - gC_mnl: cute.Tensor, - epi_tile: cute.Tile, + epi_tidx: cutlass.Int32, + warp_idx: cutlass.Int32, + mma_tile_coord_mnl: Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + tma_atom_c: cute.CopyAtom, + tCtAcc: cute.Tensor, sC: cute.Tensor, - ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: - """Make tiledCopy for global memory store, then use it to: - - partition register array (source) and global memory (destination) for none TMA store version; - - partition shared memory (source) and global memory (destination) for TMA store version. - - :param tidx: The thread index in epilogue warp groups - :type tidx: cutlass.Int32 - :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version - :type atom: cute.CopyAtom or cute.TiledCopy - :param gC_mnl: The global tensor C - :type gC_mnl: cute.Tensor - :param epi_tile: The epilogue tiler - :type epi_tile: cute.Tile - :param sC: The shared memory tensor to be copied and partitioned - :type sC: cute.Tensor - - :return: A tuple containing either: - - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: - - tma_atom_c: The TMA copy atom - - bSG_sC: The partitioned shared memory tensor C - - bSG_gC: The partitioned global tensor C - - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: - - simt_atom: The SIMT copy atom - - tTR_rC: The register tensor C - - tTR_gC: The partitioned global tensor C - :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + tCgC: cute.Tensor, + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ) -> None: """ - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gC_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + Epilogue implementation for TMA store version. + + :param epi_tidx: Thread index + :type epi_tidx: cutlass.Int32 + :param warp_idx: Warp index + :type warp_idx: cutlass.Int32 + :param tCtAcc: Partitioned accumulator tensor + :type tCtAcc: cute.Tensor + :param sC: Shared memory C tensor + :type sC: cute.Tensor + :param tCgC: Global memory C tensor + :type tCgC: cute.Tensor + :param epi_tile: Epilogue tile + :type epi_tile: cute.Tile + :param epilogue_op: Epilogue operation + :type epilogue_op: cutlass.Constexpr + :param tma_atom_c: TMA atom for C tensor + :type tma_atom_c: cute.CopyAtom + """ + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc, tCgC, epi_tile, self.use_2cta_instrs ) - if cutlass.const_expr(self.use_tma_store): - tma_atom_c = atom - sC_for_tma_partition = cute.group_modes(sC, 0, 2) - gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) - # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) - bSG_sC, bSG_gC = cpasync.tma_partition( - tma_atom_c, - 0, - cute.make_layout(1), - sC_for_tma_partition, - gC_for_tma_partition, + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # ((ATOM_V, REST_V), EPI_M, EPI_N) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + tCgC_epi = cute.flat_divide( + tCgC[((None, None), 0, 0, None, None, None)], epi_tile + ) + + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + # + # Store accumulator to global memory in sub-tiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, ) - return tma_atom_c, bSG_sC, bSG_gC - else: - tiled_copy_t2r = atom - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - tTR_gC = thr_copy_t2r.partition_D(gC_epi) - # (T2R, T2R_M, T2R_N) - tTR_rC = cute.make_fragment( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype - ) - simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) - return simt_atom, tTR_rC, tTR_gC + pipeline.sync(barrier_id=1) + + # TMA store C to global memory + if warp_idx == 0: + cute.copy( + tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)] + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + pipeline.sync(barrier_id=1) + + # Wait for C store complete + c_pipeline.producer_tail() + + @cute.jit + def epilogue( + self, + epi_tidx: cutlass.Int32, + mma_tile_coord_mnl: Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + tCtAcc: cute.Tensor, + tCgC: cute.Tensor, + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ) -> None: + """ + Epilogue implementation for non-TMA store version. + + :param epi_tidx: Thread index + :type epi_tidx: cutlass.Int32 + :param tCtAcc: Partitioned accumulator tensor + :type tCtAcc: cute.Tensor + :param tCgC: Global memory C tensor + :type tCgC: cute.Tensor + :param epi_tile: Epilogue tile + :type epi_tile: cute.Tile + :param epilogue_op: Epilogue operation + :type epilogue_op: cutlass.Constexpr + """ + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc, tCgC, epi_tile, self.use_2cta_instrs + ) + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + tCgC_epi = cute.flat_divide( + tCgC[((None, None), 0, 0, None, None, None)], epi_tile + ) + + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_gC = thr_copy_t2r.partition_D(tCgC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + # + # Store accumulator to global memory in sub-tiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) @staticmethod def _compute_stages( @@ -1326,27 +1263,21 @@ class DenseGemmKernel: tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) - @staticmethod def is_valid_dtypes( - ab_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - c_dtype: Type[cutlass.Numeric], + self, ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric] ) -> bool: """ Check if the dtypes are valid :param ab_dtype: The data type of the A and B operands :type ab_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :return: True if the dtypes are valid, False otherwise :rtype: bool """ - is_valid = True - if ab_dtype not in { + valid_ab_dtypes = { cutlass.Float16, cutlass.BFloat16, cutlass.TFloat32, @@ -1354,21 +1285,36 @@ class DenseGemmKernel: cutlass.Int8, cutlass.Float8E4M3FN, cutlass.Float8E5M2, - }: - is_valid = False - if ( - acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} - or acc_dtype == cutlass.Float16 - and ab_dtype - not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} - or acc_dtype == cutlass.Int32 - and ab_dtype not in {cutlass.Uint8, cutlass.Int8} - ): - is_valid = False - if ( - acc_dtype == cutlass.Float32 - and c_dtype - not in { + } + if ab_dtype not in valid_ab_dtypes: + return False + + if self.acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}: + return False + + # Define compatibility mapping between accumulator type and AB type + acc_ab_compatibility = { + cutlass.Float32: { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, # Float32 accumulator supports floating point AB types only + cutlass.Float16: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, + } + # Check compatibility between accumulator type and AB type + if ab_dtype not in acc_ab_compatibility[self.acc_dtype]: + return False + + # Define compatibility mapping between accumulator type and C type + acc_c_compatibility = { + cutlass.Float32: { cutlass.Float32, cutlass.Float16, cutlass.BFloat16, @@ -1377,42 +1323,28 @@ class DenseGemmKernel: cutlass.Int32, cutlass.Int8, cutlass.Uint8, - } - or acc_dtype == cutlass.Float16 - and c_dtype - not in { + }, + cutlass.Float16: { cutlass.BFloat16, cutlass.Float16, - } - or acc_dtype == cutlass.Int32 - and c_dtype - not in { + }, + cutlass.Int32: { cutlass.BFloat16, cutlass.Float16, cutlass.Float32, cutlass.Int32, cutlass.Int8, cutlass.Uint8, - } - ): - is_valid = False - return is_valid + }, + } + # Check compatibility between accumulator type and C type + if c_dtype not in acc_c_compatibility[self.acc_dtype]: + return False - @staticmethod - def is_valid_mma_tiler_and_cluster_shape( - use_2cta_instrs: bool, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - ) -> bool: - """ - Check if the mma tiler and cluster shape are valid + return True - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] - :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster - :type cluster_shape_mn: Tuple[int, int] + def is_valid_mma_tiler_and_cluster_shape(self) -> bool: + """Check if the mma tiler and cluster shape are valid. :return: True if the mma tiler and cluster shape are valid, False otherwise :rtype: bool @@ -1420,29 +1352,29 @@ class DenseGemmKernel: is_valid = True # Skip invalid mma tile shape if not ( - (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) - or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + (not self.use_2cta_instrs and self.mma_tiler_mn[0] in [64, 128]) + or (self.use_2cta_instrs and self.mma_tiler_mn[0] in [128, 256]) ): is_valid = False - if mma_tiler_mn[1] not in range(32, 257, 32): + if self.mma_tiler_mn[1] not in range(32, 257, 32): is_valid = False # Skip illegal cluster shape - if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + if self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0: is_valid = False # Skip invalid cluster shape is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 if ( - cluster_shape_mn[0] * cluster_shape_mn[1] > 16 - or cluster_shape_mn[0] <= 0 - or cluster_shape_mn[1] <= 0 - or not is_power_of_2(cluster_shape_mn[0]) - or not is_power_of_2(cluster_shape_mn[1]) + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] > 16 + or self.cluster_shape_mn[0] <= 0 + or self.cluster_shape_mn[1] <= 0 + or not is_power_of_2(self.cluster_shape_mn[0]) + or not is_power_of_2(self.cluster_shape_mn[1]) ): is_valid = False return is_valid - @staticmethod def is_valid_tensor_alignment( + self, m: int, n: int, k: int, @@ -1480,41 +1412,29 @@ class DenseGemmKernel: """ is_valid = True - def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + # TODO: move to utils + def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape): major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] num_contiguous_elements = 16 * 8 // dtype.width return num_major_elements % num_contiguous_elements == 0 if ( - not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) - or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) - or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) ): is_valid = False return is_valid - @staticmethod - def is_valid_epilog_store_option( - use_2cta_instrs: bool, - use_tma_store: bool, - m: int, - n: int, - mma_tiler_mn: Tuple[int, int], - ) -> bool: + def is_valid_epilog_store_option(self, m: int, n: int) -> bool: """ Check if the epilogue store option is valid - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param use_tma_store: Whether to use TMA store - :type use_tma_store: bool :param m: The number of rows in the A tensor :type m: int :param n: The number of columns in the B tensor :type n: int - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] :return: True if the epilogue store option is valid, False otherwise :rtype: bool @@ -1523,89 +1443,106 @@ class DenseGemmKernel: is_valid = True # None TMA store version does not have predication, can not support OOB tiles cta_tile_shape_mn = ( - mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), - mma_tiler_mn[1], + self.mma_tiler_mn[0] // (2 if self.use_2cta_instrs else 1), + self.mma_tiler_mn[1], ) - if not use_tma_store: + if not self.use_tma_store: if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): is_valid = False return is_valid - @staticmethod - def can_implement( - ab_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - c_dtype: Type[cutlass.Numeric], - use_2cta_instrs: bool, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - use_tma_store: bool, - m: int, - n: int, - k: int, - l: int, - a_major: str, - b_major: str, - c_major: str, - ) -> bool: - """ - Check if the gemm can be implemented + def can_implement(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor) -> bool: + """Check if the given tensors can be implemented by this kernel. - :param ab_dtype: The data type of the A and B operands - :type ab_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] - :param c_dtype: The data type of the output tensor - :type c_dtype: Type[cutlass.Numeric] - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] - :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster - :type cluster_shape_mn: Tuple[int, int] - :param use_tma_store: Whether to use TMA store - :type use_tma_store: bool - :param m: The number of rows in the A tensor - :type m: int - :param n: The number of columns in the B tensor - :type n: int - :param k: The number of columns in the A tensor - :type k: int - :param l: The number of columns in the C tensor - :type l: int - :param a_major: The major axis of the A tensor - :type a_major: str - :param b_major: The major axis of the B tensor - :type b_major: str - :param c_major: The major axis of the C tensor - :type c_major: str + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor - :return: True if the gemm can be implemented, False otherwise + :return: True if the gemm supports the given config, False otherwise :rtype: bool """ + m, n, k, l = a.shape[0], b.shape[0], a.shape[1], a.shape[2] + + # infer a_major, b_major, c_major + is_m_major_a = utils.LayoutEnum.from_tensor(a).is_m_major_a() + is_n_major_b = utils.LayoutEnum.from_tensor(b).is_n_major_b() + is_m_major_c = utils.LayoutEnum.from_tensor(c).is_m_major_c() + a_major = "m" if is_m_major_a else "k" + b_major = "n" if is_n_major_b else "k" + c_major = "m" if is_m_major_c else "n" + can_implement = True # Skip unsupported types - if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + if not self.is_valid_dtypes(a.element_type, c.element_type): can_implement = False # Skip invalid mma tile shape and cluster shape - if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( - use_2cta_instrs, mma_tiler_mn, cluster_shape_mn - ): + if not self.is_valid_mma_tiler_and_cluster_shape(): can_implement = False # Skip illegal problem shape for load/store alignment - if not DenseGemmKernel.is_valid_tensor_alignment( - m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + if not self.is_valid_tensor_alignment( + m, n, k, l, a.element_type, c.element_type, a_major, b_major, c_major ): can_implement = False # Skip invalid epilogue store option - if not DenseGemmKernel.is_valid_epilog_store_option( - use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn - ): + if not self.is_valid_epilog_store_option(m, n): can_implement = False + return can_implement -def run_dense_gemm( +def create_tensors(l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype): + torch.manual_seed(1111) + + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major == "m", c_dtype) + + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + return ( + a_tensor, + b_tensor, + c_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + c_torch_gpu, + ) + + +def compare(a_torch_cpu, b_torch_cpu, c_torch_gpu, c_dtype, tolerance): + # Copy gpu result back + kernel_result = c_torch_gpu.cpu() + + # Compute reference result + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch_cpu.to(dtype=torch.float32), + b_torch_cpu.to(dtype=torch.float32), + ) + + # Convert ref to c_dtype + _, ref_torch_gpu = cutlass_torch.cute_tensor_like( + ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + ref_result = ref_torch_gpu.cpu() + + # Assert close results + torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05) + + +def run( mnkl: Tuple[int, int, int, int], ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], @@ -1613,20 +1550,60 @@ def run_dense_gemm( a_major: str, b_major: str, c_major: str, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - use_2cta_instrs: bool, - use_tma_store: bool, - tolerance: float, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Tuple[int, int] = (2, 1), + use_2cta_instrs: bool = True, + use_tma_store: bool = True, + tolerance: float = 1e-01, warmup_iterations: int = 0, iterations: int = 1, skip_ref_check: bool = False, - measure_launch_overhead=False, + use_cold_l2: bool = False, + **kwargs, ): + """Execute a batched dense GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the + default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type mma_tiler_mn: Tuple[int, int], optional + :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the + default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type cluster_shape_mn: Tuple[int, int], optional + :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner + will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_2cta_instrs: bool, optional + :param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use + the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_tma_store: bool, optional + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float """ - Prepare A/B/C tensors, launch GPU kernel, and reference checking. - """ - print(f"Running B100 Dense GEMM test with:") + print("Running Blackwell Dense GEMM test with:") print(f"mnkl: {mnkl}") print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") @@ -1637,168 +1614,78 @@ def run_dense_gemm( print(f"Warmup iterations: {warmup_iterations}") print(f"Iterations: {iterations}") print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") # Unpack parameters m, n, k, l = mnkl - # Skip unsupported testcase - if not DenseGemmKernel.can_implement( - ab_dtype, - acc_dtype, - c_dtype, - use_2cta_instrs, - mma_tiler_mn, - cluster_shape_mn, - use_tma_store, - m, - n, - k, - l, - a_major, - b_major, - c_major, - ): - raise TypeError( - f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" - ) - if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") - torch.manual_seed(1111) + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) - # Create and permute tensor A/B/C - def create_and_permute_tensor( - l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True - ): - # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) - # else: (l, mode0, mode1) -> (mode0, mode1, l) - shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) - permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) - is_unsigned = dtype in {cutlass.Uint8} - # Temporarily use uint8 as torch does not support fp8 type - torch_dtype = ( - cutlass_torch.dtype(dtype) - if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} - else torch.uint8 - ) - - # Create dtype torch tensor (cpu) - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch_dtype, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 - ), - ) - # Create dtype torch tensor (gpu) - torch_tensor = torch_tensor_cpu.cuda() - - # Create f32 torch tensor (cpu) - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - # Create dtype cute tensor (gpu) - cute_tensor = from_dlpack(torch_tensor, assumed_align=16) - cute_tensor.element_type = dtype - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic( - leading_dim=(0 if is_mode0_major else 1) - ) - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=is_dynamic_layout, - ) - - return f32_torch_tensor, cute_tensor, torch_tensor - - a_ref, a_tensor, a_torch = create_and_permute_tensor( - l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True - ) - b_ref, b_tensor, b_torch = create_and_permute_tensor( - l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True - ) - c_ref, c_tensor, c_torch = create_and_permute_tensor( - l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + a_tensor, b_tensor, c_tensor, a_torch_cpu, b_torch_cpu, c_torch_cpu, c_torch_gpu = ( + create_tensors(l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype) ) - # Configure gemm kernel + # Build GEMM object gemm = DenseGemmKernel( - acc_dtype, - use_2cta_instrs, - mma_tiler_mn, - cluster_shape_mn, - use_tma_store, + acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, use_tma_store ) - torch_stream = torch.cuda.Stream() - stream = cuda.CUstream(torch_stream.cuda_stream) - # Compile gemm kernel - compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) - - # Launch GPU kernel - # Warm up - for i in range(warmup_iterations): - compiled_gemm(a_tensor, b_tensor, c_tensor, stream) - # Execution - for i in range(iterations): - compiled_gemm(a_tensor, b_tensor, c_tensor, stream) - - # Compute reference result - if not skip_ref_check: - if ab_dtype in { - cutlass.Int8, - cutlass.Uint8, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: - ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) - else: - ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() - - # Copy gpu result back - gpu_c = c_torch.cpu() - - # Convert ref to c_type - if c_dtype == cutlass.Float32: - ref_c = ref - elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: - # m major: (l, n, m) -> (m, n, l) - # n major: (l, m, n) -> (m, n, l) - permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) - shape = (l, m, n) if c_major == "n" else (l, n, m) - f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch.uint8, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.SKIP, - ).cuda() - # Create dtype cute tensor (gpu) - ref_c_tensor = from_dlpack( - f8_torch_tensor, assumed_align=16 - ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) - ref_c_tensor.element_type = c_dtype - ref_c_tensor = cutlass_torch.convert_cute_tensor( - ref, - ref_c_tensor, - c_dtype, - is_dynamic_layout=True, - ) - - ref_c = f8_torch_tensor.cpu() - else: - ref_c = ref.to(cutlass_torch.dtype(c_dtype)) - - # Reference checking ref_c and gpu_c - torch.testing.assert_close( - gpu_c, - ref_c, - atol=tolerance, - rtol=1e-05, + # Check if configuration can be implemented + can_implement = gemm.can_implement(a_tensor, b_tensor, c_tensor) + if not can_implement: + raise ValueError( + f"The current config which is invalid/unsupported: use_2cta_instrs = {use_2cta_instrs}, " + f"mma_tiler_mn = {mma_tiler_mn}, cluster_shape_mn = {cluster_shape_mn}, " + f"use_tma_store = {use_tma_store}" ) + max_active_clusters = utils.HardwareInfo().get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, current_stream) + + if not skip_ref_check: + compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) + compare(a_torch_cpu, b_torch_cpu, c_torch_gpu, c_dtype, tolerance) + + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds if __name__ == "__main__": @@ -1806,15 +1693,12 @@ if __name__ == "__main__": def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) - # or: return tuple([int(x.strip()) for x in s.split(",")]) except ValueError: raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers." ) - parser = argparse.ArgumentParser( - description="Example of MxNxKxL GEMM on Blackwell." - ) + parser = argparse.ArgumentParser(description="Example of Dense GEMM on Blackwell.") parser.add_argument( "--mnkl", @@ -1826,7 +1710,7 @@ if __name__ == "__main__": "--mma_tiler_mn", type=parse_comma_separated_ints, default=(128, 128), - help="Mma tiler (comma-separated)", + help="Mma tile shape (comma-separated)", ) parser.add_argument( "--cluster_shape_mn", @@ -1854,10 +1738,21 @@ if __name__ == "__main__": parser.add_argument( "--warmup_iterations", type=int, default=0, help="Warmup iterations" ) - parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference checking" ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() @@ -1870,7 +1765,7 @@ if __name__ == "__main__": if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") - run_dense_gemm( + run( args.mnkl, args.ab_dtype, args.c_dtype, @@ -1886,5 +1781,6 @@ if __name__ == "__main__": args.warmup_iterations, args.iterations, args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py b/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py new file mode 100644 index 00000000..c6262f91 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py @@ -0,0 +1,2218 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Tuple, Type, Union + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import cpasync, tcgen05 + + +""" +A high-performance persistent batched dense GEMM (D = alpha * A * B + beta * C) example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Matrix D is MxNxL, L is batch dimension, D can be row-major("N") or column-major("M") +- alpha and beta are float scalars + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Load C matrix from global memory (GMEM) to shared memory (SMEM) using TMA operations and then copied to registers (RMEM). + - Compute D = alpha * accumulator + beta * C. + - Type convert D matrix to output type. + - Store D matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +Input arguments to this example is same as dense_gemm.py. + +.. code-block:: bash + + python examples/blackwell/dense_gemm_alpha_beta_persistent.py \ + --ab_dtype Float16 --c_dtype Float16 --d_dtype Float16 --acc_dtype Float32 --epi_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_2cta_instrs --alpha 2.0 --beta 1.0 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm_alpha_beta_persistent.py \ + --ab_dtype Float16 --c_dtype Float16 --d_dtype Float16 --acc_dtype Float32 --epi_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_2cta_instrs --alpha 2.0 --beta 1.0 \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below SM100PersistentDenseGemmAlphaBetaKernel class documentation +* A/B tensor must have the same data type +* C/D tensor must have the same major order +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C/D tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class SM100PersistentDenseGemmAlphaBetaKernel: + """This class implements batched matrix multiplication (D = alpha * A * B + beta * C) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param epi_dtype: Data type for epilogue operation + :type epi_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C/D data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = SM100PersistentDenseGemmAlphaBetaKernel( + ... acc_dtype=cutlass.Float32, + ... epi_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2), + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, d_tensor, alpha, beta, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + epi_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param epi_dtype: Data type of the epilogue. + :type epi_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.epi_dtype: Type[cutlass.Numeric] = epi_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler_mn = mma_tiler_mn + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_ids = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.epilog_load_warp_id = 6 + self.threads_per_cta = 32 * len( + ( + self.mma_warp_id, + self.tma_warp_id, + *self.epilog_warp_ids, + self.epilog_load_warp_id, + ) + ) + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 1 + self.epilog_sync_bar_id = 2 + self.tmem_alloc_sync_bar_id = 3 + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C/D stage counts in shared memory + - Computing A/B/C/D shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + layout_d=self.cd_layout, + elem_ty_d=self.d_dtype, + layout_c=self.cd_layout, + elem_ty_c=self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_d_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.d_dtype, + self.cd_layout, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.cd_layout, + self.epi_tile, + self.num_c_stage, + ) + self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.d_dtype, + self.cd_layout, + self.epi_tile, + self.num_d_stage, + ) + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + d: cute.Tensor, + alpha: cutlass.Float32, + beta: cutlass.Float32, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Input tensor C + :type c: cute.Tensor + :param d: Output tensor D + :type d: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param alpha: Scalar multiplier for the matrix product of A and B + :type alpha: cutlass.Float32 + :param beta: Scalar multiplier for the matrix C + :type beta: cutlass.Float32 + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.d_dtype: Type[cutlass.Numeric] = d.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.cd_layout = utils.LayoutEnum.from_tensor(c) + if cutlass.const_expr(self.cd_layout != utils.LayoutEnum.from_tensor(d)): + raise ValueError("C and D must have the same layout.") + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA for C/D + tma_atom_d = None + tma_tensor_d = None + tma_atom_c = None + tma_tensor_c = None + self.tma_c_load_bytes = 0 + d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d, + d_smem_layout, + self.epi_tile, + ) + + c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + self.tma_c_load_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + c, + c_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + d, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + d_smem_size = cute.cosize(self.d_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + c_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] + c_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sD: cute.struct.Align[ + cute.struct.MemRange[ + self.d_dtype, + d_smem_size, + ], + self.buffer_align_bytes, + ] + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tma_atom_d, + tma_tensor_d, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.d_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + alpha, + beta, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + tma_atom_d: Optional[cute.CopyAtom], + mD_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + alpha: cutlass.Float32, + beta: cutlass.Float32, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + cpasync.prefetch_descriptor(tma_atom_d) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_ids) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Load C pipeline + c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + c_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len(self.epilog_warp_ids), + ) + c_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.c_full_mbar_ptr.data_ptr(), + num_stages=self.num_c_stage, + producer_group=c_producer_group, + consumer_group=c_consumer_group, + tx_count=self.tma_c_load_bytes, + ) + + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_ids)), + ) + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_ids[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C/D + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + sD = storage.sD.get_tensor( + d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gD_mnl = cute.local_tile( + mD_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C/D + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgD = thr_mma.partition_C(gD_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), tiles_m, tiles_k, tiles_l) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), tiles_n, tiles_k, tiles_l) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C/D + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # Named barriers + # + cta_sync_barrier = pipeline.NamedBarrier( + self.cta_sync_bar_id, self.threads_per_cta + ) + epilog_sync_barrier = pipeline.NamedBarrier( + self.epilog_sync_bar_id, 32 * len(self.epilog_warp_ids) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cta_sync_barrier.arrive_and_wait() + + # + # Specialized TMA load warp + # + + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # tCtAcc += tCrA * tCrB + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range( + num_k_blocks, unroll_full=True + ): + k_block_coord = ( + None, + None, + k_block_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first k_block + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgD, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_s2r = None + simt_atom_c = None + tSR_rC = None + tSR_sC = None + tTR_gC_partitioned = None + + tTR_rD = None + tiled_copy_r2s = None + simt_atom_d = None + tRS_rD = None + tRS_sD = None + bSG_sD = None + bSG_gD_partitioned = None + tTR_gD_partitioned = None + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + (tiled_copy_s2r, tSR_rC, tSR_sC) = self.epilog_smem_copy_and_partition_load( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) + ( + tiled_copy_r2s, + tRS_rD, + tRS_sD, + ) = self.epilog_smem_copy_and_partition_store( + tiled_copy_t2r, tTR_rD, epi_tidx, sD + ) + ( + tma_atom_d, + bSG_sD, + bSG_gD_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_d, tCgD, epi_tile, sD, self.d_dtype + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Store D pipeline + d_pipeline = None + c_pipeline_consumer_state = None + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_ids), + ) + d_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_d_stage, + producer_group=d_producer_group, + ) + + c_pipeline_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_c_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + tTR_gC = None + bSG_gD = None + tTR_gD = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gD = bSG_gD_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + # Wait for C load to complete + c_pipeline.consumer_wait(c_pipeline_consumer_state) + + # Load C from shared memory to register + cute.copy( + tiled_copy_s2r, + tSR_sC[(None, None, None, c_pipeline_consumer_state.index)], + tSR_rC, + ) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + c_pipeline.consumer_release(c_pipeline_consumer_state) + + # Advance pipeline states + c_pipeline_consumer_state.advance() + + # + # Perform epilogue op on accumulator and convert to D type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + c_vec_load = tiled_copy_r2s.retile(tSR_rC).load() + d_vec = epilogue_op( + ( + alpha.to(self.epi_dtype) * acc_vec.to(self.epi_dtype) + + beta.to(self.epi_dtype) * c_vec_load.to(self.epi_dtype) + ) + ).to(self.d_dtype) + tRS_rD.store(d_vec) + + # + # Store C to shared memory + # + d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage + cute.copy( + tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)] + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_sync_barrier.arrive_and_wait() + + # + # TMA store D to global memory + # + if warp_idx == self.epilog_warp_ids[0]: + cute.copy( + tma_atom_d, + bSG_sD[(None, d_buffer)], + bSG_gD[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + d_pipeline.producer_commit() + d_pipeline.producer_acquire() + epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for D store complete + # + d_pipeline.producer_tail() + + # + # Specialized epilog load warp + # + if warp_idx == self.epilog_load_warp_id: + ( + tma_atom_c, + bGS_sC, + bGS_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + tidx, tma_atom_c, tCgC, epi_tile, sC, self.c_dtype + ) + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + c_pipeline_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_c_stage + ) + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + bGS_gC = bGS_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC)) + subtile_cnt = cute.size(bGS_gC.shape, mode=[1]) + for subtile_idx in cutlass.range(subtile_cnt): + # Load C from global memory to shared memory using TMA load + c_pipeline.producer_acquire(c_pipeline_producer_state) + cute.copy( + tma_atom_c, + bGS_gC[(None, subtile_idx)], + bGS_sC[(None, c_pipeline_producer_state.index)], + tma_bar_ptr=c_pipeline.producer_get_barrier( + c_pipeline_producer_state + ), + ) + c_pipeline_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait C buffer empty + # + c_pipeline.producer_tail(c_pipeline_producer_state) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.cd_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, 0, 0, 0)], epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition_load( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory load, then use it to partition register array (destination) and shared memory (source). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_s2r, tSR_rC, tSR_sC) where: + - tiled_copy_s2r: The tiled copy operation for smem to register copy(s2r) + - tSR_rC: The partitioned tensor C (register destination) + - tSR_sC: The partitioned tensor C (smem source) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_s2r = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r) + # (S2R, S2R_M, S2R_N, PIPE_C) + thr_copy_s2r = tiled_copy_s2r.get_slice(tidx) + tSR_sC = thr_copy_s2r.partition_D(sC) + # (S2R, S2R_M, S2R_N) + tSR_rC = tiled_copy_s2r.retile(tTR_rC) + return tiled_copy_s2r, tSR_rC, tSR_sC + + def epilog_smem_copy_and_partition_store( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.cd_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + dtype: Type[cutlass.Numeric], + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, tiles_m, tiles_n, tiles_l) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, tiles_m, tiles_n, tiles_l) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + cd_layout: utils.LayoutEnum, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C/D operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (input). + :type c_dtype: type[cutlass.Numeric] + :param d_dtype: Data type of operand D (output). + :type d_dtype: type[cutlass.Numeric] + :param cd_layout: Layout of operand C/D in global memory. + :type cd_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 2 + + # Default C stages + num_c_stage = 2 + + # Default D stages + num_d_stage = 2 + + # Calculate smem layout and size for one stage of A, B, C, and D + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + cd_layout, + epi_tile, + 1, + ) + d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + d_dtype, + cd_layout, + epi_tile, + 1, + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) + d_bytes = d_bytes_per_stage * num_d_stage + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C/D stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes + d_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_d_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes + d_bytes) + ) // (occupancy * d_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage + + @staticmethod + def _compute_grid( + d: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param d: The output tensor D + :type d: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + d_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gd = cute.zipped_divide(d, tiler=d_shape) + num_ctas_mnl = gd[(0, (None, None, None))].shape + + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: Tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + :param num_acc_stage: The stage of the accumulator tensor. + :type num_acc_stage: int + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + return num_tmem_alloc_cols + + def is_valid_dtypes( + self, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + valid_ab_dtypes = { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + } + if ab_dtype not in valid_ab_dtypes: + return False + + if self.acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}: + return False + + # Define compatibility mapping between accumulator type and AB type + acc_ab_compatibility = { + cutlass.Float32: { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, # Float32 accumulator supports floating point AB types only + cutlass.Float16: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, + } + # Check compatibility between accumulator type and AB type + if ab_dtype not in acc_ab_compatibility[self.acc_dtype]: + return False + + # Define compatibility mapping between accumulator type and C type + acc_c_compatibility = { + cutlass.Float32: { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + }, + cutlass.Float16: { + cutlass.BFloat16, + cutlass.Float16, + }, + cutlass.Int32: { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + }, + } + # Check compatibility between accumulator type and C type + if c_dtype not in acc_c_compatibility[self.acc_dtype]: + return False + + return True + + def is_valid_mma_tiler_and_cluster_shape(self) -> bool: + """Check if the mma tiler and cluster shape are valid. + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not self.use_2cta_instrs and self.mma_tiler_mn[0] in [64, 128]) + or (self.use_2cta_instrs and self.mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if self.mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal cluster shape + if self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] > 16 + or self.cluster_shape_mn[0] <= 0 + or self.cluster_shape_mn[1] <= 0 + or not is_power_of_2(self.cluster_shape_mn[0]) + or not is_power_of_2(self.cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + def is_valid_tensor_alignment( + self, + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + cd_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the C tensor + :type c_dtype: Type[cutlass.Numeric] + :param d_dtype: The data type of the D tensor + :type d_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param cd_major: The major axis of the C/D tensor + :type cd_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, cd_major == "m", (m, n, l)) + or not check_contigous_16B_alignment(d_dtype, cd_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + def can_implement( + self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, d: cute.Tensor + ) -> bool: + """Check if the given tensors can be implemented by this kernel. + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + + :return: True if the gemm supports the given config, False otherwise + :rtype: bool + """ + m, n, k, l = a.shape[0], b.shape[0], a.shape[1], a.shape[2] + + # infer a_major, b_major, cd_major + is_m_major_a = utils.LayoutEnum.from_tensor(a).is_m_major_a() + is_n_major_b = utils.LayoutEnum.from_tensor(b).is_n_major_b() + is_m_major_c = utils.LayoutEnum.from_tensor(c).is_m_major_c() + a_major = "m" if is_m_major_a else "k" + b_major = "n" if is_n_major_b else "k" + cd_major = "m" if is_m_major_c else "n" + + can_implement = True + # Skip unsupported types + if not self.is_valid_dtypes(a.element_type, c.element_type): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not self.is_valid_mma_tiler_and_cluster_shape(): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not self.is_valid_tensor_alignment( + m, + n, + k, + l, + a.element_type, + c.element_type, + d.element_type, + a_major, + b_major, + cd_major, + ): + can_implement = False + + return can_implement + + +def create_tensors(l, m, n, k, a_major, b_major, cd_major, ab_dtype, c_dtype, d_dtype): + torch.manual_seed(1111) + + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, cd_major == "m", c_dtype) + d_torch_cpu = cutlass_torch.matrix(l, m, n, cd_major == "m", d_dtype) + + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + d_tensor, d_torch_gpu = cutlass_torch.cute_tensor_like( + d_torch_cpu, d_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + return ( + a_tensor, + b_tensor, + c_tensor, + d_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + d_torch_gpu, + ) + + +def run( + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + d: cute.Tensor, + stream: cuda.CUstream, + alpha: float, + beta: float, + acc_dtype: Type[cutlass.Numeric] = cutlass.Float32, + epi_dtype: Type[cutlass.Numeric] = cutlass.Float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Tuple[int, int] = (2, 1), + use_2cta_instrs: bool = True, + warmup_iterations: int = 0, + iterations: int = 1, +): + """Run the gemm kernel utility function. It will return the compiled gemm kernel function and its execution time. + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream + :type stream: cuda.CUstream + :param acc_dtype: Accumulator data type, defaults to cutlass.Float32 + :type acc_dtype: cutlass.DataType, optional + :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the + default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type mma_tiler_mn: Tuple[int, int], optional + :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the + default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type cluster_shape_mn: Tuple[int, int], optional + :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner + will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_2cta_instrs: bool, optional + + :return: Compiled GEMM kernel function and its execution time + :rtype: Callable + """ + # Build GEMM object + gemm = SM100PersistentDenseGemmAlphaBetaKernel( + acc_dtype, + epi_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + ) + + # Check if configuration can be implemented + can_implement = gemm.can_implement(a, b, c, d) + if not can_implement: + raise ValueError( + f"The current config which is invalid/unsupported: use_2cta_instrs = {use_2cta_instrs}, " + f"mma_tiler_mn = {mma_tiler_mn}, cluster_shape_mn = {cluster_shape_mn}" + ) + max_active_clusters = utils.HardwareInfo().get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + compiled_gemm = cute.compile( + gemm, a, b, c, d, alpha, beta, max_active_clusters, stream + ) + exec_time = testing.benchmark( + compiled_gemm, + kernel_arguments=testing.JitArguments(a, b, c, d, alpha, beta, stream), + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + return compiled_gemm, exec_time + + +def compare( + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + d_torch_gpu, + d_dtype, + epi_dtype, + alpha, + beta, + tolerance, +): + # Copy gpu result back + kernel_result = d_torch_gpu.cpu() + + # Compute reference result + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch_cpu.to(dtype=torch.float32), + b_torch_cpu.to(dtype=torch.float32), + ) + + torch_epi_dtype = cutlass_torch.dtype(epi_dtype) + torch_alpha = torch.tensor(alpha, dtype=torch_epi_dtype) + torch_beta = torch.tensor(beta, dtype=torch_epi_dtype) + ref_d_epi_dtype = torch_alpha * ref.to( + dtype=torch_epi_dtype + ) + torch_beta * c_torch_cpu.to(dtype=torch_epi_dtype) + + # Convert ref to d_dtype + _, ref_torch_gpu = cutlass_torch.cute_tensor_like( + ref_d_epi_dtype, d_dtype, is_dynamic_layout=True, assumed_align=16 + ) + ref_result = ref_torch_gpu.cpu() + + # Assert close results + torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05) + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + epi_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + cd_major: str, + alpha: float, + beta: float, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, +): + """ + Prepare A/B/C/D tensors, launch GPU kernel, and reference checking. + """ + print("Running Blackwell Persistent Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print( + f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, D dtype: {d_dtype}, Acc dtype: {acc_dtype}, Epi dtype: {epi_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {cd_major}, D: {cd_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + ( + a_tensor, + b_tensor, + c_tensor, + d_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + d_torch_gpu, + ) = create_tensors( + l, m, n, k, a_major, b_major, cd_major, ab_dtype, c_dtype, d_dtype + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + compiled_gemm, exec_time = run( + a_tensor, + b_tensor, + c_tensor, + d_tensor, + current_stream, + alpha, + beta, + acc_dtype, + epi_dtype, + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + warmup_iterations, + iterations, + ) + + print(f"Execution time: {exec_time} us") + + # Compute reference result + if not skip_ref_check: + compare( + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + d_torch_gpu, + d_dtype, + epi_dtype, + alpha, + beta, + tolerance, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Dense Persistent GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--epi_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--alpha", type=float, default=1.0, help="alpha scale factor") + parser.add_argument("--beta", type=float, default=0.0, help="beta scale factor") + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--cd_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run_dense_gemm( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.d_dtype, + args.acc_dtype, + args.epi_dtype, + args.a_major, + args.b_major, + args.cd_major, + args.alpha, + args.beta, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py index f5022a82..2ff8ff26 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py @@ -27,21 +27,19 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse -from typing import Optional, Type, Tuple, Union +from typing import Optional, Tuple, Type, Union -import cuda.bindings.driver as cuda import torch +import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.cute.testing as testing import cutlass.torch as cutlass_torch import cutlass.utils as utils import cutlass.pipeline as pipeline -import cutlass.cute.testing as testing import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass.cute.runtime import from_dlpack - +from cutlass.cute.nvgpu import cpasync, tcgen05 """ A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture @@ -111,6 +109,84 @@ Constraints are same as dense_gemm.py: """ +def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + c_smem_layout: Union[cute.Layout, None], +) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + :param c_smem_layout: Layout of C operand in shared memory, or None if not using TMA store. + :type c_smem_layout: Union[cute.Layout, None] + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 2 + + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C with 1-stage + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, mma_tiler_mnk, a_dtype, 1 + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, mma_tiler_mnk, b_dtype, 1 + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + class PersistentDenseGemmKernel: """This class implements batched matrix multiplication (C = A x B) with support for various data types and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. @@ -203,6 +279,7 @@ class PersistentDenseGemmKernel: self.use_2cta_instrs = use_2cta_instrs self.cluster_shape_mn = cluster_shape_mn # K dimension is deferred in _setup_attributes + self.mma_tiler_mn = mma_tiler_mn self.mma_tiler = (*mma_tiler_mn, 1) self.use_tma_store = use_tma_store @@ -212,21 +289,16 @@ class PersistentDenseGemmKernel: self.occupancy = 1 # Set specialized warp ids - self.epilog_warp_id = ( - 0, - 1, - 2, - 3, - ) + self.epilog_warp_id = (0, 1, 2, 3) self.mma_warp_id = 4 self.tma_warp_id = 5 self.threads_per_cta = 32 * len( (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) ) # Set barrier id for cta sync, epilogue sync and tmem ptr sync - self.cta_sync_bar_id = 0 self.epilog_sync_bar_id = 1 - self.tmem_ptr_sync_bar_id = 2 + self.tmem_alloc_sync_bar_id = 2 + self.tmem_dealloc_sync_bar_id = 3 self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") def _setup_attributes(self): @@ -290,43 +362,38 @@ class PersistentDenseGemmKernel: else: self.epi_tile = self.cta_tile_shape_mnk[:2] + c_smem_layout = None + if cutlass.const_expr(self.use_tma_store): + c_smem_layout = sm100_utils.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, 1 + ) + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory - self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = _compute_stages( tiled_mma, self.mma_tiler, self.a_dtype, self.b_dtype, - self.epi_tile, self.c_dtype, - self.c_layout, self.smem_capacity, self.occupancy, self.use_tma_store, + c_smem_layout, ) # Compute A/B/C shared memory layout self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, - self.mma_tiler, - self.a_dtype, - self.num_ab_stage, + tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage ) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, - self.mma_tiler, - self.b_dtype, - self.num_ab_stage, + tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage ) - self.c_smem_layout_staged = ( - sm100_utils.make_smem_layout_epi( - self.c_dtype, - self.c_layout, - self.epi_tile, - self.num_c_stage, + + self.c_smem_layout_staged = None + if self.use_tma_store: + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage ) - if cutlass.const_expr(self.use_tma_store) - else None - ) # Compute the number of tensor memory allocation columns self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( @@ -432,15 +499,9 @@ class PersistentDenseGemmKernel: tma_atom_c = None tma_tensor_c = None if cutlass.const_expr(self.use_tma_store): - c_cta_v_layout = cute.composition( - cute.make_identity_layout(c.shape), self.epi_tile - ) - epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1]) tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileS2GOp(), - c, - epi_smem_layout, - c_cta_v_layout, + cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile ) # Compute grid size @@ -448,48 +509,6 @@ class PersistentDenseGemmKernel: c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters ) - self.buffer_align_bytes = 1024 - - c_smem_size = ( - cute.cosize(self.c_smem_layout_staged.outer) - if cutlass.const_expr(self.use_tma_store) - else 0 - ) - - # Define shared storage for kernel - @cute.struct - class SharedStorage: - ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 - tmem_holding_buf: cutlass.Int32 - # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC: cute.struct.Align[ - cute.struct.MemRange[ - self.c_dtype, - c_smem_size, - ], - self.buffer_align_bytes, - ] - # (MMA, MMA_M, MMA_K, STAGE) - sA: cute.struct.Align[ - cute.struct.MemRange[ - self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) - ], - self.buffer_align_bytes, - ] - # (MMA, MMA_N, MMA_K, STAGE) - sB: cute.struct.Align[ - cute.struct.MemRange[ - self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) - ], - self.buffer_align_bytes, - ] - - self.shared_storage = SharedStorage - # Launch the kernel synchronously self.kernel( tiled_mma, @@ -498,7 +517,7 @@ class PersistentDenseGemmKernel: tma_atom_b, tma_tensor_b, tma_atom_c, - tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c, + tma_tensor_c if self.use_tma_store else c, self.cluster_layout_vmnk, self.a_smem_layout_staged, self.b_smem_layout_staged, @@ -569,11 +588,18 @@ class PersistentDenseGemmKernel: # # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier # - smem = utils.SmemAllocator() - storage = smem.allocate(self.shared_storage) + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_stage * 2 + ] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -581,14 +607,14 @@ class PersistentDenseGemmKernel: ab_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_tma_producer ) - ab_pipeline = pipeline.PipelineTmaUmma.create( + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, - ) + ).make_participants() # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -606,15 +632,24 @@ class PersistentDenseGemmKernel: cta_layout_vmnk=cluster_layout_vmnk, ) + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + tmem_dealloc_barrier = None + if cutlass.const_expr(not self.use_tma_store): + tmem_dealloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_dealloc_sync_bar_id, + num_threads=32 * len(self.epilog_warp_id), + ) # Tensor memory dealloc barrier init - if use_2cta_instrs: - if warp_idx == self.tma_warp_id: - num_tmem_dealloc_threads = 32 - with cute.arch.elect_one(): - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads - ) - cute.arch.mbarrier_init_fence() + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) # Cluster arrive after barrier init if cute.size(self.cluster_shape_mn) > 1: @@ -623,21 +658,19 @@ class PersistentDenseGemmKernel: # # Setup smem tensor A/B/C # - # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC = ( - storage.sC.get_tensor( - c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner - ) - if cutlass.const_expr(self.use_tma_store) - else None - ) # (MMA, MMA_M, MMA_K, STAGE) - sA = storage.sA.get_tensor( - a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, ) # (MMA, MMA_N, MMA_K, STAGE) - sB = storage.sB.get_tensor( - b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, ) # @@ -731,9 +764,7 @@ class PersistentDenseGemmKernel: if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() else: - cute.arch.barrier( - barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta - ) + cute.arch.sync_threads() # # Specialized TMA load warp @@ -748,10 +779,6 @@ class PersistentDenseGemmKernel: ) work_tile = tile_sched.initial_work_tile_info() - ab_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_ab_stage - ) - while work_tile.is_valid_tile: # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx @@ -774,44 +801,36 @@ class PersistentDenseGemmKernel: ] # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt - ab_producer_state.reset_count() - peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + ab_producer.reset() + peek_ab_empty_status = ab_producer.try_acquire() + # # Tma load loop # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty - ab_pipeline.producer_acquire( - ab_producer_state, peek_ab_empty_status - ) + handle = ab_producer.acquire_and_advance(peek_ab_empty_status) # TMA load A/B cute.copy( tma_atom_a, - tAgA_slice[(None, ab_producer_state.count)], - tAsA[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tAgA_slice[(None, handle.count)], + tAsA[(None, handle.index)], + tma_bar_ptr=handle.barrier, mcast_mask=a_full_mcast_mask, ) cute.copy( tma_atom_b, - tBgB_slice[(None, ab_producer_state.count)], - tBsB[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tBgB_slice[(None, handle.count)], + tBsB[(None, handle.index)], + tma_bar_ptr=handle.barrier, mcast_mask=b_full_mcast_mask, ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 - ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) - if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + if handle.count + 1 < k_tile_cnt: + peek_ab_empty_status = ab_producer.try_acquire() # # Advance to next tile @@ -822,29 +841,17 @@ class PersistentDenseGemmKernel: # # Wait A/B buffer empty # - ab_pipeline.producer_tail(ab_producer_state) + ab_producer.tail() # # Specialized MMA warp # if warp_idx == self.mma_warp_id: - # - # Bar sync for retrieve tensor memory ptr from shared mem - # - tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) - cute.arch.barrier( - barrier_id=self.tmem_ptr_sync_bar_id, - number_of_threads=tmem_ptr_read_threads, - ) - # # Retrieving tensor memory ptr and make accumulator tensor # - tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, - alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, - ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) @@ -856,9 +863,6 @@ class PersistentDenseGemmKernel: ) work_tile = tile_sched.initial_work_tile_info() - ab_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_ab_stage - ) acc_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.num_acc_stage ) @@ -877,12 +881,10 @@ class PersistentDenseGemmKernel: tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] # Peek (try_wait) AB buffer full for k_tile = 0 - ab_consumer_state.reset_count() + ab_consumer.reset() peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_tile_cnt and is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + if is_leader_cta: + peek_ab_full_status = ab_consumer.try_wait() # # Wait for accumulator buffer empty @@ -901,41 +903,30 @@ class PersistentDenseGemmKernel: for k_tile in range(k_tile_cnt): if is_leader_cta: # Conditionally wait for AB buffer full - ab_pipeline.consumer_wait( - ab_consumer_state, peek_ab_full_status - ) + handle = ab_consumer.wait_and_advance(peek_ab_full_status) # tCtAcc += tCrA * tCrB num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = ( - None, - None, - kblock_idx, - ab_consumer_state.index, - ) + for kblk_idx in cutlass.range(num_kblocks, unroll_full=True): + kblk_crd = (None, None, kblk_idx, handle.index) cute.gemm( tiled_mma, tCtAcc, - tCrA[kblock_coord], - tCrB[kblock_coord], + tCrA[kblk_crd], + tCrB[kblk_crd], tCtAcc, ) # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty - ab_pipeline.consumer_release(ab_consumer_state) + handle.release() - # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 - ab_consumer_state.advance() - peek_ab_full_status = cutlass.Boolean(1) - if ab_consumer_state.count < k_tile_cnt: - if is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + peek_ab_full_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_full_status = ab_consumer.try_wait() # # Async arrive accumulator buffer full @@ -954,6 +945,17 @@ class PersistentDenseGemmKernel: # Wait for accumulator buffer empty # acc_pipeline.producer_tail(acc_producer_state) + + sC = None + if cutlass.const_expr(self.use_tma_store): + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, + ) + # # Specialized epilogue warps # @@ -961,260 +963,310 @@ class PersistentDenseGemmKernel: # # Alloc tensor memory buffer # - if warp_idx == self.epilog_warp_id[0]: - cute.arch.alloc_tmem( - self.num_tmem_alloc_cols, - tmem_holding_buf, - is_two_cta=use_2cta_instrs, - ) - - # - # Bar sync for retrieve tensor memory ptr from shared memory - # - tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) - cute.arch.barrier( - barrier_id=self.tmem_ptr_sync_bar_id, - number_of_threads=tmem_ptr_read_threads, - ) + tmem.allocate(self.num_tmem_alloc_cols) # # Retrieving tensor memory ptr and make accumulator tensor # - tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, - alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, - ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) # - # Partition for epilogue - # - epi_tidx = tidx - ( - tiled_copy_t2r, - tTR_tAcc_base, - tTR_rAcc, - ) = self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs - ) - - tTR_rC = None - tiled_copy_r2s = None - simt_atom = None - tRS_rC = None - tRS_sC = None - bSG_sC = None - bSG_gC_partitioned = None - tTR_gC_partitioned = None - if cutlass.const_expr(self.use_tma_store): - tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( - tiled_copy_t2r, tTR_rC, epi_tidx, sC - ) - ( - tma_atom_c, - bSG_sC, - bSG_gC_partitioned, - ) = self.epilog_gmem_copy_and_partition( - epi_tidx, tma_atom_c, tCgC, epi_tile, sC - ) - else: - ( - simt_atom, - tTR_rC, - tTR_gC_partitioned, - ) = self.epilog_gmem_copy_and_partition( - epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC - ) - - # - # Persistent tile scheduling loop + # Persistent tile scheduling loop for epilogue # tile_sched = utils.StaticPersistentTileScheduler.create( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) - work_tile = tile_sched.initial_work_tile_info() - acc_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) - - c_pipeline = None if cutlass.const_expr(self.use_tma_store): - # Threads/warps participating in tma store pipeline - c_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - 32 * len(self.epilog_warp_id), - 32 * len(self.epilog_warp_id), + assert tma_atom_c is not None and sC is not None + self.epilogue_tma_store( + tidx, + warp_idx, + acc_pipeline, + tiled_mma, + tma_atom_c, + tCtAcc_base, + sC, + tCgC, + epi_tile, + tile_sched, + epilogue_op, ) - c_pipeline = pipeline.PipelineTmaStore.create( - num_stages=self.num_c_stage, - producer_group=c_producer_group, + else: + self.epilogue( + tidx, + acc_pipeline, + tiled_mma, + tCtAcc_base, + tCgC, + epi_tile, + tile_sched, + epilogue_op, + tmem_dealloc_barrier, ) - while work_tile.is_valid_tile: - # Get tile coord from tile scheduler - cur_tile_coord = work_tile.tile_idx - mma_tile_coord_mnl = ( - cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), - cur_tile_coord[1], - cur_tile_coord[2], - ) - - # - # Slice to per mma tile index - # - bSG_gC = None - tTR_gC = None - if cutlass.const_expr(self.use_tma_store): - # ((ATOM_V, REST_V), EPI_M, EPI_N) - bSG_gC = bSG_gC_partitioned[ - ( - None, - None, - None, - *mma_tile_coord_mnl, - ) - ] - else: - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) - tTR_gC = tTR_gC_partitioned[ - ( - None, - None, - None, - None, - None, - *mma_tile_coord_mnl, - ) - ] - - # Set tensor memory buffer for current tile - # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = tTR_tAcc_base[ - (None, None, None, None, None, acc_consumer_state.index) - ] - - # - # Wait for accumulator buffer full - # - acc_pipeline.consumer_wait(acc_consumer_state) - - tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) - if cutlass.const_expr(self.use_tma_store): - bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) - else: - tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) - - # - # Store accumulator to global memory in subtiles - # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt - for subtile_idx in cutlass.range(subtile_cnt): - # - # Load accumulator from tensor memory buffer to register - # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] - cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - - if cutlass.const_expr(self.use_tma_store): - # - # Convert to C type - # - acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() - acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) - tRS_rC.store(acc_vec) - - # - # Store C to shared memory - # - c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage - cute.copy( - tiled_copy_r2s, - tRS_rC, - tRS_sC[(None, None, None, c_buffer)], - ) - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) - epilog_threads = 32 * len(self.epilog_warp_id) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=epilog_threads, - ) - - # - # TMA store C to global memory - # - if warp_idx == self.epilog_warp_id[0]: - cute.copy( - tma_atom_c, - bSG_sC[(None, c_buffer)], - bSG_gC[(None, subtile_idx)], - ) - # Fence and barrier to make sure shared memory store is visible to TMA store - c_pipeline.producer_commit() - c_pipeline.producer_acquire() - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=epilog_threads, - ) - else: - # - # Convert to C type - # - acc_vec = tTR_rAcc.load() - acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) - tTR_rC.store(acc_vec) - - # - # Store C to global memory - # - cute.copy( - simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)] - ) - - # - # Async arrive accumulator buffer empty - # - with cute.arch.elect_one(): - acc_pipeline.consumer_release(acc_consumer_state) - acc_consumer_state.advance() - - # - # Advance to next tile - # - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - # # Dealloc the tensor memory buffer # - if warp_idx == self.epilog_warp_id[0]: - cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) - epilog_threads = 32 * len(self.epilog_warp_id) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + @cute.jit + def epilogue_tma_store( + self, + epi_tidx: cutlass.Int32, + warp_idx: cutlass.Int32, + acc_pipeline: pipeline.PipelineAsync, + tiled_mma: cute.TiledMma, + tma_atom_c: cute.CopyAtom, + # Input of epilogue + tCtAcc_base: cute.Tensor, + # Staging of epilogue + sC: cute.Tensor, + # Output of epilogue + tCgC: cute.Tensor, + epi_tile: cute.Tile, + tile_sched: utils.StaticPersistentTileScheduler, + epilogue_op: cutlass.Constexpr, + ) -> None: + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, self.use_2cta_instrs + ) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + tCgC_epi = cute.flat_divide( + tCgC[((None, None), 0, 0, None, None, None)], epi_tile + ) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC_partitioned = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilog_warp_id), + ) + + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], ) - if warp_idx == self.epilog_warp_id[0]: - if use_2cta_instrs: - cute.arch.mbarrier_arrive( - tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 - ) - cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - cute.arch.dealloc_tmem( - tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, ) + epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + epilog_sync_barrier.arrive_and_wait() + + epilog_sync_barrier.arrive_and_wait() + # - # Wait for C store complete + # Async arrive accumulator buffer empty # - if cutlass.const_expr(self.use_tma_store): - c_pipeline.producer_tail() + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Wait for C store complete + c_pipeline.producer_tail() + + @cute.jit + def epilogue( + self, + epi_tidx: cutlass.Int32, + acc_pipeline: pipeline.PipelineAsync, + tiled_mma: cute.TiledMma, + tCtAcc_base: cute.Tensor, + tCgC: cute.Tensor, + epi_tile: cute.Tile, + tile_sched: utils.StaticPersistentTileScheduler, + epilogue_op: cutlass.Constexpr, + tmem_dealloc_barrier: pipeline.NamedBarrier, + ) -> None: + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, self.use_2cta_instrs + ) + + gC_epi = cute.flat_divide( + tCgC[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_gC_partitioned = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_rmem_tensor( + tTR_gC_partitioned[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC_partitioned[ + (None, None, None, None, None, *mma_tile_coord_mnl) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Synchronize before TMEM dealloc (done by the caller) + tmem_dealloc_barrier.arrive_and_wait() def epilog_tmem_copy_and_partition( self, @@ -1254,10 +1306,7 @@ class PersistentDenseGemmKernel: use_2cta_instrs, ) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) - tAcc_epi = cute.flat_divide( - tAcc[((None, None), 0, 0, None)], - epi_tile, - ) + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) # (EPI_TILE_M, EPI_TILE_N) tiled_copy_t2r = tcgen05.make_tmem_copy( copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] @@ -1274,7 +1323,7 @@ class PersistentDenseGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_fragment( + tTR_rAcc = cute.make_rmem_tensor( tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc @@ -1316,169 +1365,6 @@ class PersistentDenseGemmKernel: tRS_rC = tiled_copy_r2s.retile(tTR_rC) return tiled_copy_r2s, tRS_rC, tRS_sC - def epilog_gmem_copy_and_partition( - self, - tidx: cutlass.Int32, - atom: Union[cute.CopyAtom, cute.TiledCopy], - gC_mnl: cute.Tensor, - epi_tile: cute.Tile, - sC: cute.Tensor, - ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: - """Make tiledCopy for global memory store, then use it to: - - partition register array (source) and global memory (destination) for none TMA store version; - - partition shared memory (source) and global memory (destination) for TMA store version. - - :param tidx: The thread index in epilogue warp groups - :type tidx: cutlass.Int32 - :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version - :type atom: cute.CopyAtom or cute.TiledCopy - :param gC_mnl: The global tensor C - :type gC_mnl: cute.Tensor - :param epi_tile: The epilogue tiler - :type epi_tile: cute.Tile - :param sC: The shared memory tensor to be copied and partitioned - :type sC: cute.Tensor - - :return: A tuple containing either: - - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: - - tma_atom_c: The TMA copy atom - - bSG_sC: The partitioned shared memory tensor C - - bSG_gC: The partitioned global tensor C - - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: - - simt_atom: The SIMT copy atom - - tTR_rC: The register tensor C - - tTR_gC: The partitioned global tensor C - :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] - """ - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gC_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile - ) - if cutlass.const_expr(self.use_tma_store): - tma_atom_c = atom - sC_for_tma_partition = cute.group_modes(sC, 0, 2) - gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) - # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) - bSG_sC, bSG_gC = cpasync.tma_partition( - tma_atom_c, - 0, - cute.make_layout(1), - sC_for_tma_partition, - gC_for_tma_partition, - ) - return tma_atom_c, bSG_sC, bSG_gC - else: - tiled_copy_t2r = atom - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - tTR_gC = thr_copy_t2r.partition_D(gC_epi) - # (T2R, T2R_M, T2R_N) - tTR_rC = cute.make_fragment( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype - ) - simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) - return simt_atom, tTR_rC, tTR_gC - - @staticmethod - def _compute_stages( - tiled_mma: cute.TiledMma, - mma_tiler_mnk: Tuple[int, int, int], - a_dtype: Type[cutlass.Numeric], - b_dtype: Type[cutlass.Numeric], - epi_tile: cute.Tile, - c_dtype: Type[cutlass.Numeric], - c_layout: utils.LayoutEnum, - smem_capacity: int, - occupancy: int, - use_tma_store: bool, - ) -> Tuple[int, int, int]: - """Computes the number of stages for A/B/C operands based on heuristics. - - :param tiled_mma: The tiled MMA object defining the core computation. - :type tiled_mma: cute.TiledMma - :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. - :type mma_tiler_mnk: tuple[int, int, int] - :param a_dtype: Data type of operand A. - :type a_dtype: type[cutlass.Numeric] - :param b_dtype: Data type of operand B. - :type b_dtype: type[cutlass.Numeric] - :param epi_tile: The epilogue tile shape. - :type epi_tile: cute.Tile - :param c_dtype: Data type of operand C (output). - :type c_dtype: type[cutlass.Numeric] - :param c_layout: Layout enum of operand C. - :type c_layout: utils.LayoutEnum - :param smem_capacity: Total available shared memory capacity in bytes. - :type smem_capacity: int - :param occupancy: Target number of CTAs per SM (occupancy). - :type occupancy: int - :param use_tma_store: Whether TMA store is enabled. - :type use_tma_store: bool - - :return: A tuple containing the computed number of stages for: - (ACC stages, A/B operand stages, C stages) - :rtype: tuple[int, int, int] - """ - # Default ACC stages - num_acc_stage = 2 - - # Default C stages - num_c_stage = 2 if use_tma_store else 0 - - # Calculate smem layout and size for one stage of A, B, and C - a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( - tiled_mma, - mma_tiler_mnk, - a_dtype, - 1, # a tmp 1 stage is provided - ) - b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( - tiled_mma, - mma_tiler_mnk, - b_dtype, - 1, # a tmp 1 stage is provided - ) - c_smem_layout_staged_one = ( - sm100_utils.make_smem_layout_epi( - c_dtype, - c_layout, - epi_tile, - 1, - ) - if use_tma_store - else None - ) - ab_bytes_per_stage = cute.size_in_bytes( - a_dtype, a_smem_layout_stage_one - ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) - mbar_helpers_bytes = 1024 - c_bytes_per_stage = ( - cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) - if use_tma_store - else 0 - ) - c_bytes = c_bytes_per_stage * num_c_stage - - # Calculate A/B stages: - # Start with total smem per CTA (capacity / occupancy) - # Subtract reserved bytes and initial C stages bytes - # Divide remaining by bytes needed per A/B stage - num_ab_stage = ( - smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) - ) // ab_bytes_per_stage - - # Refine epilogue stages: - # Calculate remaining smem after allocating for A/B stages and reserved bytes - # Add remaining unused smem to epilogue - if use_tma_store: - num_c_stage += ( - smem_capacity - - occupancy * ab_bytes_per_stage * num_ab_stage - - occupancy * (mbar_helpers_bytes + c_bytes) - ) // (occupancy * c_bytes_per_stage) - return num_acc_stage, num_ab_stage, num_c_stage - @staticmethod def _compute_grid( c: cute.Tensor, @@ -1541,11 +1427,8 @@ class PersistentDenseGemmKernel: return num_tmem_alloc_cols - @staticmethod def is_valid_dtypes( - ab_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - c_dtype: Type[cutlass.Numeric], + self, ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric] ) -> bool: """ Check if the dtypes are valid @@ -1560,8 +1443,7 @@ class PersistentDenseGemmKernel: :return: True if the dtypes are valid, False otherwise :rtype: bool """ - is_valid = True - if ab_dtype not in { + valid_ab_dtypes = { cutlass.Float16, cutlass.BFloat16, cutlass.TFloat32, @@ -1569,21 +1451,36 @@ class PersistentDenseGemmKernel: cutlass.Int8, cutlass.Float8E4M3FN, cutlass.Float8E5M2, - }: - is_valid = False - if ( - acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} - or acc_dtype == cutlass.Float16 - and ab_dtype - not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} - or acc_dtype == cutlass.Int32 - and ab_dtype not in {cutlass.Uint8, cutlass.Int8} - ): - is_valid = False - if ( - acc_dtype == cutlass.Float32 - and c_dtype - not in { + } + if ab_dtype not in valid_ab_dtypes: + return False + + if self.acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}: + return False + + # Define compatibility mapping between accumulator type and AB type + acc_ab_compatibility = { + cutlass.Float32: { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, # Float32 accumulator supports floating point AB types only + cutlass.Float16: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, + } + # Check compatibility between accumulator type and AB type + if ab_dtype not in acc_ab_compatibility[self.acc_dtype]: + return False + + # Define compatibility mapping between accumulator type and C type + acc_c_compatibility = { + cutlass.Float32: { cutlass.Float32, cutlass.Float16, cutlass.BFloat16, @@ -1592,42 +1489,28 @@ class PersistentDenseGemmKernel: cutlass.Int32, cutlass.Int8, cutlass.Uint8, - } - or acc_dtype == cutlass.Float16 - and c_dtype - not in { + }, + cutlass.Float16: { cutlass.BFloat16, cutlass.Float16, - } - or acc_dtype == cutlass.Int32 - and c_dtype - not in { + }, + cutlass.Int32: { cutlass.BFloat16, cutlass.Float16, cutlass.Float32, cutlass.Int32, cutlass.Int8, cutlass.Uint8, - } - ): - is_valid = False - return is_valid + }, + } + # Check compatibility between accumulator type and C type + if c_dtype not in acc_c_compatibility[self.acc_dtype]: + return False - @staticmethod - def is_valid_mma_tiler_and_cluster_shape( - use_2cta_instrs: bool, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - ) -> bool: - """ - Check if the mma tiler and cluster shape are valid + return True - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] - :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster - :type cluster_shape_mn: Tuple[int, int] + def is_valid_mma_tiler_and_cluster_shape(self) -> bool: + """Check if the mma tiler and cluster shape are valid. :return: True if the mma tiler and cluster shape are valid, False otherwise :rtype: bool @@ -1635,29 +1518,29 @@ class PersistentDenseGemmKernel: is_valid = True # Skip invalid mma tile shape if not ( - (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) - or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + (not self.use_2cta_instrs and self.mma_tiler_mn[0] in [64, 128]) + or (self.use_2cta_instrs and self.mma_tiler_mn[0] in [128, 256]) ): is_valid = False - if mma_tiler_mn[1] not in range(32, 257, 32): + if self.mma_tiler_mn[1] not in range(32, 257, 32): is_valid = False # Skip illegal cluster shape - if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + if self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0: is_valid = False # Skip invalid cluster shape is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 if ( - cluster_shape_mn[0] * cluster_shape_mn[1] > 16 - or cluster_shape_mn[0] <= 0 - or cluster_shape_mn[1] <= 0 - or not is_power_of_2(cluster_shape_mn[0]) - or not is_power_of_2(cluster_shape_mn[1]) + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] > 16 + or self.cluster_shape_mn[0] <= 0 + or self.cluster_shape_mn[1] <= 0 + or not is_power_of_2(self.cluster_shape_mn[0]) + or not is_power_of_2(self.cluster_shape_mn[1]) ): is_valid = False return is_valid - @staticmethod def is_valid_tensor_alignment( + self, m: int, n: int, k: int, @@ -1695,41 +1578,29 @@ class PersistentDenseGemmKernel: """ is_valid = True - def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + # TODO: move to utils + def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape): major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] num_contiguous_elements = 16 * 8 // dtype.width return num_major_elements % num_contiguous_elements == 0 if ( - not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) - or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) - or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) ): is_valid = False return is_valid - @staticmethod - def is_valid_epilog_store_option( - use_2cta_instrs: bool, - use_tma_store: bool, - m: int, - n: int, - mma_tiler_mn: Tuple[int, int], - ) -> bool: + def is_valid_epilog_store_option(self, m: int, n: int) -> bool: """ Check if the epilogue store option is valid - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param use_tma_store: Whether to use TMA store - :type use_tma_store: bool :param m: The number of rows in the A tensor :type m: int :param n: The number of columns in the B tensor :type n: int - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] :return: True if the epilogue store option is valid, False otherwise :rtype: bool @@ -1738,88 +1609,105 @@ class PersistentDenseGemmKernel: is_valid = True # None TMA store version does not have predication, can not support OOB tiles cta_tile_shape_mn = ( - mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), - mma_tiler_mn[1], + self.mma_tiler_mn[0] // (2 if self.use_2cta_instrs else 1), + self.mma_tiler_mn[1], ) - if not use_tma_store: + if not self.use_tma_store: if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): is_valid = False return is_valid - @staticmethod - def can_implement( - ab_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - c_dtype: Type[cutlass.Numeric], - use_2cta_instrs: bool, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - use_tma_store: bool, - m: int, - n: int, - k: int, - l: int, - a_major: str, - b_major: str, - c_major: str, - ) -> bool: - """ - Check if the gemm can be implemented + def can_implement(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor) -> bool: + """Check if the given tensors can be implemented by this kernel. - :param ab_dtype: The data type of the A and B operands - :type ab_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] - :param c_dtype: The data type of the output tensor - :type c_dtype: Type[cutlass.Numeric] - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] - :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster - :type cluster_shape_mn: Tuple[int, int] - :param use_tma_store: Whether to use TMA store - :type use_tma_store: bool - :param m: The number of rows in the A tensor - :type m: int - :param n: The number of columns in the B tensor - :type n: int - :param k: The number of columns in the A tensor - :type k: int - :param l: The number of columns in the C tensor - :type l: int - :param a_major: The major axis of the A tensor - :type a_major: str - :param b_major: The major axis of the B tensor - :type b_major: str - :param c_major: The major axis of the C tensor - :type c_major: str + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor - :return: True if the gemm can be implemented, False otherwise + :return: True if the gemm supports the given config, False otherwise :rtype: bool """ + m, n, k, l = a.shape[0], b.shape[0], a.shape[1], a.shape[2] + + # infer a_major, b_major, c_major + is_m_major_a = utils.LayoutEnum.from_tensor(a).is_m_major_a() + is_n_major_b = utils.LayoutEnum.from_tensor(b).is_n_major_b() + is_m_major_c = utils.LayoutEnum.from_tensor(c).is_m_major_c() + a_major = "m" if is_m_major_a else "k" + b_major = "n" if is_n_major_b else "k" + c_major = "m" if is_m_major_c else "n" + can_implement = True # Skip unsupported types - if not PersistentDenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + if not self.is_valid_dtypes(a.element_type, c.element_type): can_implement = False # Skip invalid mma tile shape and cluster shape - if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( - use_2cta_instrs, mma_tiler_mn, cluster_shape_mn - ): + if not self.is_valid_mma_tiler_and_cluster_shape(): can_implement = False # Skip illegal problem shape for load/store alignment - if not PersistentDenseGemmKernel.is_valid_tensor_alignment( - m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + if not self.is_valid_tensor_alignment( + m, n, k, l, a.element_type, c.element_type, a_major, b_major, c_major ): can_implement = False # Skip invalid epilogue store option - if not PersistentDenseGemmKernel.is_valid_epilog_store_option( - use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn - ): + if not self.is_valid_epilog_store_option(m, n): can_implement = False + return can_implement +def create_tensors(l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype): + torch.manual_seed(1111) + + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major == "m", c_dtype) + + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + return ( + a_tensor, + b_tensor, + c_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + c_torch_gpu, + ) + + +def compare(a_torch_cpu, b_torch_cpu, c_torch_gpu, c_dtype, tolerance): + # Copy gpu result back + kernel_result = c_torch_gpu.cpu() + + # Compute reference result + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch_cpu.to(dtype=torch.float32), + b_torch_cpu.to(dtype=torch.float32), + ) + + # Convert ref to c_dtype + _, ref_torch_gpu = cutlass_torch.cute_tensor_like( + ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + ref_result = ref_torch_gpu.cpu() + + # Assert close results + torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05) + + def run( mnkl: Tuple[int, int, int, int], ab_dtype: Type[cutlass.Numeric], @@ -1881,7 +1769,7 @@ def run( :return: Execution time of the GEMM kernel :rtype: float """ - print(f"Running Blackwell Persistent Dense GEMM test with:") + print("Running Blackwell Persistent Dense GEMM test with:") print(f"mnkl: {mnkl}") print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") @@ -1897,166 +1785,41 @@ def run( # Unpack parameters m, n, k, l = mnkl - # Skip unsupported testcase - if not PersistentDenseGemmKernel.can_implement( - ab_dtype, - acc_dtype, - c_dtype, - use_2cta_instrs, - mma_tiler_mn, - cluster_shape_mn, - use_tma_store, - m, - n, - k, - l, - a_major, - b_major, - c_major, - ): - raise TypeError( - f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" - ) - if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") - torch.manual_seed(1111) - - # Create and permute tensor A/B/C - def create_and_permute_tensor( - l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True - ): - # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) - # else: (l, mode0, mode1) -> (mode0, mode1, l) - shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) - permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) - is_unsigned = dtype in {cutlass.Uint8} - # Temporarily use uint8 as torch does not support fp8 type - torch_dtype = ( - cutlass_torch.dtype(dtype) - if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} - else torch.uint8 - ) - - # Create dtype torch tensor (cpu) - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch_dtype, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 - ), - ) - # Create dtype torch tensor (gpu) - torch_tensor = torch_tensor_cpu.cuda() - - # Create f32 torch tensor (cpu) - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - # Create dtype cute tensor (gpu) - cute_tensor = from_dlpack(torch_tensor, assumed_align=16) - cute_tensor.element_type = dtype - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic( - leading_dim=(0 if is_mode0_major else 1) - ) - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=is_dynamic_layout, - ) - - return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu - - a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor( - l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True - ) - b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor( - l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True - ) - c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor( - l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True - ) - - # Configure gemm kernel - gemm = PersistentDenseGemmKernel( - acc_dtype, - use_2cta_instrs, - mma_tiler_mn, - cluster_shape_mn, - use_tma_store, - ) - - # Compute max active clusters on current device - hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mn[0] * cluster_shape_mn[1] - ) - # Get current CUDA stream from PyTorch torch_stream = torch.cuda.current_stream() # Get the raw stream pointer as a CUstream current_stream = cuda.CUstream(torch_stream.cuda_stream) - # Compile gemm kernel + + a_tensor, b_tensor, c_tensor, a_torch_cpu, b_torch_cpu, c_torch_cpu, c_torch_gpu = ( + create_tensors(l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype) + ) + + # Build GEMM object + gemm = PersistentDenseGemmKernel( + acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, use_tma_store + ) + + # Check if configuration can be implemented + can_implement = gemm.can_implement(a_tensor, b_tensor, c_tensor) + if not can_implement: + raise ValueError( + f"The current config which is invalid/unsupported: use_2cta_instrs = {use_2cta_instrs}, " + f"mma_tiler_mn = {mma_tiler_mn}, cluster_shape_mn = {cluster_shape_mn}, " + f"use_tma_store = {use_tma_store}" + ) + max_active_clusters = utils.HardwareInfo().get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) compiled_gemm = cute.compile( gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream ) if not skip_ref_check: compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) - if ab_dtype in { - cutlass.Int8, - cutlass.Uint8, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: - ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) - else: - ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() - - # Copy gpu result back - gpu_c = c_torch.cpu() - - # Convert ref to c_type - if c_dtype == cutlass.Float32: - ref_c = ref - elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: - # m major: (l, n, m) -> (m, n, l) - # n major: (l, m, n) -> (m, n, l) - permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) - shape = (l, m, n) if c_major == "n" else (l, n, m) - f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch.uint8, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.SKIP, - ).cuda() - # Create dtype cute tensor (gpu) - ref_c_tensor = from_dlpack( - f8_torch_tensor, assumed_align=16 - ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) - ref_c_tensor.element_type = c_dtype - ref_c_tensor = cutlass_torch.convert_cute_tensor( - ref, - ref_c_tensor, - c_dtype, - is_dynamic_layout=True, - ) - - ref_c = f8_torch_tensor.cpu() - else: - ref_c = ref.to(cutlass_torch.dtype(c_dtype)) - - # Reference checking ref_c and gpu_c - torch.testing.assert_close( - gpu_c, - ref_c, - atol=tolerance, - rtol=1e-05, - ) + compare(a_torch_cpu, b_torch_cpu, c_torch_gpu, c_dtype, tolerance) def generate_tensors(): a_tensor, _ = cutlass_torch.cute_tensor_like( diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py index 8f5e172e..3876c288 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py @@ -39,7 +39,7 @@ import cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass.cute.runtime import from_dlpack +import cutlass.cute.testing as testing """ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture @@ -203,6 +203,7 @@ class DenseGemmKernel: self.use_2cta_instrs = use_2cta_instrs self.cluster_shape_mn = cluster_shape_mn # K dimension is deferred in _setup_attributes + self.mma_tiler_mn = mma_tiler_mn self.mma_tiler = (*mma_tiler_mn, 1) self.use_tma_store = use_tma_store @@ -414,59 +415,17 @@ class DenseGemmKernel: tma_atom_c = None tma_tensor_c = None if cutlass.const_expr(self.use_tma_store): - c_cta_v_layout = cute.composition( - cute.make_identity_layout(c.shape), self.epi_tile - ) epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, - c_cta_v_layout, + self.epi_tile, ) # Compute grid size grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) - self.buffer_align_bytes = 1024 - - c_smem_size = ( - cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 - ) - - # Define shared storage for kernel - @cute.struct - class SharedStorage: - ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 - tmem_holding_buf: cutlass.Int32 - # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC: cute.struct.Align[ - cute.struct.MemRange[ - self.c_dtype, - c_smem_size, - ], - self.buffer_align_bytes, - ] - # (MMA, MMA_M, MMA_K, STAGE) - sA: cute.struct.Align[ - cute.struct.MemRange[ - self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) - ], - self.buffer_align_bytes, - ] - # (MMA, MMA_N, MMA_K, STAGE) - sB: cute.struct.Align[ - cute.struct.MemRange[ - self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) - ], - self.buffer_align_bytes, - ] - - self.shared_storage = SharedStorage - # Launch the kernel synchronously self.kernel( tiled_mma, @@ -551,11 +510,17 @@ class DenseGemmKernel: # # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier # - smem = utils.SmemAllocator() - storage = smem.allocate(self.shared_storage) + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_stage * 2 + ] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -563,25 +528,20 @@ class DenseGemmKernel: ab_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_tma_producer ) - ab_pipeline = pipeline.PipelineTmaUmma.create( + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, - ) - ab_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_ab_stage - ) - ab_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_ab_stage - ) + ).make_participants() # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) acc_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + pipeline.Agent.Thread, + self.threads_per_cta, ) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), @@ -597,15 +557,16 @@ class DenseGemmKernel: pipeline.PipelineUserType.Consumer, self.num_acc_stage ) + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=0, num_threads=self.threads_per_cta + ) # Tensor memory dealloc barrier init - if use_2cta_instrs: - if warp_idx == 0: - num_tmem_dealloc_threads = 32 - with cute.arch.elect_one(): - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads - ) - cute.arch.mbarrier_init_fence() + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) # Cluster arrive after barrier init if cute.size(self.cluster_shape_mn) > 1: @@ -615,20 +576,28 @@ class DenseGemmKernel: # Setup smem tensor A/B/C # # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC = ( - storage.sC.get_tensor( - c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + sC = None + if cutlass.const_expr(self.use_tma_store): + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, ) - if self.use_tma_store - else None - ) + # (MMA, MMA_M, MMA_K, STAGE) - sA = storage.sA.get_tensor( - a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, ) # (MMA, MMA_N, MMA_K, STAGE) - sB = storage.sB.get_tensor( - b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, ) # @@ -720,56 +689,16 @@ class DenseGemmKernel: if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() - # # Alloc tensor memory buffer - # - if warp_idx == 0: - cute.arch.alloc_tmem( - self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs - ) + tmem.allocate(self.num_tmem_alloc_cols) - # - # Bar sync for retrieve tensor memory ptr from shared memory - # - cute.arch.barrier() + # Barrier before retrieve tensor memory ptr from shared memory + tmem.wait_for_alloc() - # - # Retrieving tensor memory ptr and make accumulator tensor - # - tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf - ) + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N) tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) - # - # Partition for epilogue - # - tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( - tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs - ) - - tTR_rC = None - tiled_copy_r2s = None - simt_atom = None - tRS_rC = None - tRS_sC = None - bSG_sC = None - bSG_gC = None - tTR_gC = None - if cutlass.const_expr(self.use_tma_store): - tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( - tiled_copy_t2r, tTR_rC, tidx, sC - ) - tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( - tidx, tma_atom_c, tCgC, epi_tile, sC - ) - else: - simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( - tidx, tiled_copy_t2r, tCgC, epi_tile, sC - ) - # # Slice to per mma tile index # @@ -777,65 +706,50 @@ class DenseGemmKernel: tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - if cutlass.const_expr(self.use_tma_store): - # ((ATOM_V, REST_V), EPI_M, EPI_N) - bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] - else: - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) - tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] - # /////////////////////////////////////////////////////////////////////////////// - # MAINLOOP - # /////////////////////////////////////////////////////////////////////////////// - prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) if warp_idx == 0: + # /////////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # /////////////////////////////////////////////////////////////////////////////// for k_tile in cutlass.range( - k_tile_cnt, - prefetch_stages=self.num_ab_stage - 2, + k_tile_cnt, prefetch_stages=self.num_ab_stage - 2 ): # wait for AB buffer empty - ab_pipeline.producer_acquire(ab_producer_state) + producer_handle = ab_producer.acquire_and_advance() - # TMA load A/B + # TMA load A/B cute.copy( tma_atom_a, - tAgA[(None, ab_producer_state.count)], - tAsA[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tAgA[(None, producer_handle.count)], + tAsA[(None, producer_handle.index)], + tma_bar_ptr=producer_handle.barrier, mcast_mask=a_full_mcast_mask, ) cute.copy( tma_atom_b, - tBgB[(None, ab_producer_state.count)], - tBsB[(None, ab_producer_state.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + tBgB[(None, producer_handle.count)], + tBsB[(None, producer_handle.index)], + tma_bar_ptr=producer_handle.barrier, mcast_mask=b_full_mcast_mask, ) if is_leader_cta: # Wait for AB buffer full - ab_pipeline.consumer_wait(ab_consumer_state) + consumer_handle = ab_consumer.wait_and_advance() # tCtAcc += tCrA * tCrB - num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + num_kblks = cute.size(tCrA, mode=[2]) + for kblk_idx in cutlass.range(num_kblks, unroll_full=True): + kblk_crd = (None, None, kblk_idx, consumer_handle.index) cute.gemm( - tiled_mma, - tCtAcc, - tCrA[kblock_coord], - tCrB[kblock_coord], - tCtAcc, + tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc ) # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty - ab_pipeline.consumer_release(ab_consumer_state) - - ab_producer_state.advance() - ab_consumer_state.advance() + consumer_handle.release() # Async arrive accumulator buffer full if is_leader_cta: @@ -846,114 +760,37 @@ class DenseGemmKernel: # # Release tensor memory allocation lock - if warp_idx == 0: - cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + tmem.relinquish_alloc_permit() # Wait for accumulator buffer full acc_pipeline.consumer_wait(acc_consumer_state) - tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) if cutlass.const_expr(self.use_tma_store): - bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + assert tma_atom_c is not None and sC is not None + self.epilogue_tma_store( + tidx, + warp_idx, + mma_tile_coord_mnl, # type: ignore + tma_atom_c, + tCtAcc, + sC, + tCgC, + epi_tile, + epilogue_op, + ) + else: - tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + self.epilogue(tidx, mma_tile_coord_mnl, tCtAcc, tCgC, epi_tile, epilogue_op) # type: ignore - c_pipeline = None - if cutlass.const_expr(self.use_tma_store): - # Initialize tma store c_pipeline - c_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta - ) - c_pipeline = pipeline.PipelineTmaStore.create( - num_stages=self.num_c_stage, - producer_group=c_producer_group, - ) - - # - # Store accumulator to global memory in subtiles - # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - for subtile_idx in range(subtile_cnt): - # - # Load accumulator from tensor memory buffer to register - # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] - cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - - if cutlass.const_expr(self.use_tma_store): - # - # Perform epilogue op on accumulator and convert to C type - # - acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() - acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) - tRS_rC.store(acc_vec) - - # - # Store C to shared memory - # - c_buffer = subtile_idx % self.num_c_stage - cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, - ) - cute.arch.barrier() - - # - # TMA store C to global memory - # - if warp_idx == 0: - cute.copy( - tma_atom_c, - bSG_sC[(None, c_buffer)], - bSG_gC[(None, subtile_idx)], - ) - # Fence and barrier to make sure TMA store is completed to recollect C buffer - c_pipeline.producer_commit() - c_pipeline.producer_acquire() - cute.arch.barrier() - else: - # - # Perform epilogue op on accumulator and convert to C type - # - acc_vec = tTR_rAcc.load() - acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) - tTR_rC.store(acc_vec) - - # - # Store C to global memory - # - cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) - - # # Dealloc the tensor memory buffer - # - cute.arch.barrier() - if warp_idx == 0: - if use_2cta_instrs: - cute.arch.mbarrier_arrive( - tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 - ) - cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - cute.arch.dealloc_tmem( - tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs - ) - - # - # Wait for C store complete - # - if cutlass.const_expr(self.use_tma_store): - c_pipeline.producer_tail() + pipeline.sync(barrier_id=1) + tmem.free(tmem_ptr) # # Wait A/B buffer empty # if warp_idx == 0: - # Reverse prefetch_k_tile_cnt times to next available buffer - for i in range(prefetch_k_tile_cnt): - ab_producer_state.reverse() - ab_pipeline.producer_tail(ab_producer_state) + ab_producer.tail() return def epilog_tmem_copy_and_partition( @@ -994,10 +831,7 @@ class DenseGemmKernel: use_2cta_instrs, ) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) - tAcc_epi = cute.flat_divide( - tAcc[((None, None), 0, 0)], - epi_tile, - ) + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0)], epi_tile) # (EPI_TILE_M, EPI_TILE_N) tiled_copy_t2r = tcgen05.make_tmem_copy( copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] @@ -1014,7 +848,7 @@ class DenseGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_fragment( + tTR_rAcc = cute.make_rmem_tensor( tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc @@ -1055,69 +889,185 @@ class DenseGemmKernel: tRS_rC = tiled_copy_r2s.retile(tTR_rC) return tiled_copy_r2s, tRS_rC, tRS_sC - def epilog_gmem_copy_and_partition( + @cute.jit + def epilogue_tma_store( self, - tidx: cutlass.Int32, - atom: Union[cute.CopyAtom, cute.TiledCopy], - gC_mnl: cute.Tensor, - epi_tile: cute.Tile, + epi_tidx: cutlass.Int32, + warp_idx: cutlass.Int32, + mma_tile_coord_mnl: Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + tma_atom_c: cute.CopyAtom, + tCtAcc: cute.Tensor, sC: cute.Tensor, - ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: - """Make tiledCopy for global memory store, then use it to: - - partition register array (source) and global memory (destination) for none TMA store version; - - partition shared memory (source) and global memory (destination) for TMA store version. - - :param tidx: The thread index in epilogue warp groups - :type tidx: cutlass.Int32 - :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version - :type atom: cute.CopyAtom or cute.TiledCopy - :param gC_mnl: The global tensor C - :type gC_mnl: cute.Tensor - :param epi_tile: The epilogue tiler - :type epi_tile: cute.Tile - :param sC: The shared memory tensor to be copied and partitioned - :type sC: cute.Tensor - - :return: A tuple containing either: - - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: - - tma_atom_c: The TMA copy atom - - bSG_sC: The partitioned shared memory tensor C - - bSG_gC: The partitioned global tensor C - - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: - - simt_atom: The SIMT copy atom - - tTR_rC: The register tensor C - - tTR_gC: The partitioned global tensor C - :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + tCgC: cute.Tensor, + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ) -> None: """ - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gC_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + Epilogue implementation for TMA store version. + + :param epi_tidx: Thread index + :type epi_tidx: cutlass.Int32 + :param warp_idx: Warp index + :type warp_idx: cutlass.Int32 + :param tCtAcc: Partitioned accumulator tensor + :type tCtAcc: cute.Tensor + :param sC: Shared memory C tensor + :type sC: cute.Tensor + :param tCgC: Global memory C tensor + :type tCgC: cute.Tensor + :param epi_tile: Epilogue tile + :type epi_tile: cute.Tile + :param epilogue_op: Epilogue operation + :type epilogue_op: cutlass.Constexpr + :param tma_atom_c: TMA atom for C tensor + :type tma_atom_c: cute.CopyAtom + """ + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc, tCgC, epi_tile, self.use_2cta_instrs ) - if cutlass.const_expr(self.use_tma_store): - tma_atom_c = atom - sC_for_tma_partition = cute.group_modes(sC, 0, 2) - gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) - # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) - bSG_sC, bSG_gC = cpasync.tma_partition( - tma_atom_c, - 0, - cute.make_layout(1), - sC_for_tma_partition, - gC_for_tma_partition, + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # ((ATOM_V, REST_V), EPI_M, EPI_N) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + tCgC_epi = cute.flat_divide( + tCgC[((None, None), 0, 0, None, None, None)], epi_tile + ) + + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + # + # Store accumulator to global memory in sub-tiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, ) - return tma_atom_c, bSG_sC, bSG_gC - else: - tiled_copy_t2r = atom - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - tTR_gC = thr_copy_t2r.partition_D(gC_epi) - # (T2R, T2R_M, T2R_N) - tTR_rC = cute.make_fragment( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype - ) - simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) - return simt_atom, tTR_rC, tTR_gC + pipeline.sync(barrier_id=1) + + # TMA store C to global memory + if warp_idx == 0: + cute.copy( + tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)] + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + pipeline.sync(barrier_id=1) + + # Wait for C store complete + c_pipeline.producer_tail() + + @cute.jit + def epilogue( + self, + epi_tidx: cutlass.Int32, + mma_tile_coord_mnl: Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + tCtAcc: cute.Tensor, + tCgC: cute.Tensor, + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ) -> None: + """ + Epilogue implementation for non-TMA store version. + + :param epi_tidx: Thread index + :type epi_tidx: cutlass.Int32 + :param tCtAcc: Partitioned accumulator tensor + :type tCtAcc: cute.Tensor + :param tCgC: Global memory C tensor + :type tCgC: cute.Tensor + :param epi_tile: Epilogue tile + :type epi_tile: cute.Tile + :param epilogue_op: Epilogue operation + :type epilogue_op: cutlass.Constexpr + """ + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc, tCgC, epi_tile, self.use_2cta_instrs + ) + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + tCgC_epi = cute.flat_divide( + tCgC[((None, None), 0, 0, None, None, None)], epi_tile + ) + + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_gC = thr_copy_t2r.partition_D(tCgC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + # + # Store accumulator to global memory in sub-tiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # Store C to global memory + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) @staticmethod def _compute_stages( @@ -1268,27 +1218,21 @@ class DenseGemmKernel: tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) - @staticmethod def is_valid_dtypes( - ab_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - c_dtype: Type[cutlass.Numeric], + self, ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric] ) -> bool: """ Check if the dtypes are valid :param ab_dtype: The data type of the A and B operands :type ab_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :return: True if the dtypes are valid, False otherwise :rtype: bool """ - is_valid = True - if ab_dtype not in { + valid_ab_dtypes = { cutlass.Float16, cutlass.BFloat16, cutlass.TFloat32, @@ -1296,21 +1240,36 @@ class DenseGemmKernel: cutlass.Int8, cutlass.Float8E4M3FN, cutlass.Float8E5M2, - }: - is_valid = False - if ( - acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} - or acc_dtype == cutlass.Float16 - and ab_dtype - not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} - or acc_dtype == cutlass.Int32 - and ab_dtype not in {cutlass.Uint8, cutlass.Int8} - ): - is_valid = False - if ( - acc_dtype == cutlass.Float32 - and c_dtype - not in { + } + if ab_dtype not in valid_ab_dtypes: + return False + + if self.acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}: + return False + + # Define compatibility mapping between accumulator type and AB type + acc_ab_compatibility = { + cutlass.Float32: { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, # Float32 accumulator supports floating point AB types only + cutlass.Float16: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, + } + # Check compatibility between accumulator type and AB type + if ab_dtype not in acc_ab_compatibility[self.acc_dtype]: + return False + + # Define compatibility mapping between accumulator type and C type + acc_c_compatibility = { + cutlass.Float32: { cutlass.Float32, cutlass.Float16, cutlass.BFloat16, @@ -1319,42 +1278,28 @@ class DenseGemmKernel: cutlass.Int32, cutlass.Int8, cutlass.Uint8, - } - or acc_dtype == cutlass.Float16 - and c_dtype - not in { + }, + cutlass.Float16: { cutlass.BFloat16, cutlass.Float16, - } - or acc_dtype == cutlass.Int32 - and c_dtype - not in { + }, + cutlass.Int32: { cutlass.BFloat16, cutlass.Float16, cutlass.Float32, cutlass.Int32, cutlass.Int8, cutlass.Uint8, - } - ): - is_valid = False - return is_valid + }, + } + # Check compatibility between accumulator type and C type + if c_dtype not in acc_c_compatibility[self.acc_dtype]: + return False - @staticmethod - def is_valid_mma_tiler_and_cluster_shape( - use_2cta_instrs: bool, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - ) -> bool: - """ - Check if the mma tiler and cluster shape are valid + return True - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] - :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster - :type cluster_shape_mn: Tuple[int, int] + def is_valid_mma_tiler_and_cluster_shape(self) -> bool: + """Check if the mma tiler and cluster shape are valid. :return: True if the mma tiler and cluster shape are valid, False otherwise :rtype: bool @@ -1362,29 +1307,29 @@ class DenseGemmKernel: is_valid = True # Skip invalid mma tile shape if not ( - (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) - or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + (not self.use_2cta_instrs and self.mma_tiler_mn[0] in [64, 128]) + or (self.use_2cta_instrs and self.mma_tiler_mn[0] in [128, 256]) ): is_valid = False - if mma_tiler_mn[1] not in range(32, 257, 32): + if self.mma_tiler_mn[1] not in range(32, 257, 32): is_valid = False # Skip illegal cluster shape - if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + if self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0: is_valid = False # Skip invalid cluster shape is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 if ( - cluster_shape_mn[0] * cluster_shape_mn[1] > 16 - or cluster_shape_mn[0] <= 0 - or cluster_shape_mn[1] <= 0 - or not is_power_of_2(cluster_shape_mn[0]) - or not is_power_of_2(cluster_shape_mn[1]) + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] > 16 + or self.cluster_shape_mn[0] <= 0 + or self.cluster_shape_mn[1] <= 0 + or not is_power_of_2(self.cluster_shape_mn[0]) + or not is_power_of_2(self.cluster_shape_mn[1]) ): is_valid = False return is_valid - @staticmethod def is_valid_tensor_alignment( + self, m: int, n: int, k: int, @@ -1422,41 +1367,28 @@ class DenseGemmKernel: """ is_valid = True - def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape): major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] num_contiguous_elements = 16 * 8 // dtype.width return num_major_elements % num_contiguous_elements == 0 if ( - not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) - or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) - or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) ): is_valid = False return is_valid - @staticmethod - def is_valid_epilog_store_option( - use_2cta_instrs: bool, - use_tma_store: bool, - m: int, - n: int, - mma_tiler_mn: Tuple[int, int], - ) -> bool: + def is_valid_epilog_store_option(self, m: int, n: int) -> bool: """ Check if the epilogue store option is valid - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param use_tma_store: Whether to use TMA store - :type use_tma_store: bool :param m: The number of rows in the A tensor :type m: int :param n: The number of columns in the B tensor :type n: int - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] :return: True if the epilogue store option is valid, False otherwise :rtype: bool @@ -1465,89 +1397,106 @@ class DenseGemmKernel: is_valid = True # None TMA store version does not have predication, can not support OOB tiles cta_tile_shape_mn = ( - mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), - mma_tiler_mn[1], + self.mma_tiler_mn[0] // (2 if self.use_2cta_instrs else 1), + self.mma_tiler_mn[1], ) - if not use_tma_store: + if not self.use_tma_store: if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): is_valid = False return is_valid - @staticmethod - def can_implement( - ab_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - c_dtype: Type[cutlass.Numeric], - use_2cta_instrs: bool, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - use_tma_store: bool, - m: int, - n: int, - k: int, - l: int, - a_major: str, - b_major: str, - c_major: str, - ) -> bool: - """ - Check if the gemm can be implemented + def can_implement(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor) -> bool: + """Check if the given tensors can be implemented by this kernel. - :param ab_dtype: The data type of the A and B operands - :type ab_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] - :param c_dtype: The data type of the output tensor - :type c_dtype: Type[cutlass.Numeric] - :param use_2cta_instrs: Whether to use 2 CTA groups - :type use_2cta_instrs: bool - :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler - :type mma_tiler_mn: Tuple[int, int] - :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster - :type cluster_shape_mn: Tuple[int, int] - :param use_tma_store: Whether to use TMA store - :type use_tma_store: bool - :param m: The number of rows in the A tensor - :type m: int - :param n: The number of columns in the B tensor - :type n: int - :param k: The number of columns in the A tensor - :type k: int - :param l: The number of columns in the C tensor - :type l: int - :param a_major: The major axis of the A tensor - :type a_major: str - :param b_major: The major axis of the B tensor - :type b_major: str - :param c_major: The major axis of the C tensor - :type c_major: str + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor - :return: True if the gemm can be implemented, False otherwise + :return: True if the gemm supports the given config, False otherwise :rtype: bool """ + m, n, k, l = a.shape[0], b.shape[0], a.shape[1], a.shape[2] + + # infer a_major, b_major, c_major + is_m_major_a = utils.LayoutEnum.from_tensor(a).is_m_major_a() + is_n_major_b = utils.LayoutEnum.from_tensor(b).is_n_major_b() + is_m_major_c = utils.LayoutEnum.from_tensor(c).is_m_major_c() + a_major = "m" if is_m_major_a else "k" + b_major = "n" if is_n_major_b else "k" + c_major = "m" if is_m_major_c else "n" + can_implement = True # Skip unsupported types - if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + if not self.is_valid_dtypes(a.element_type, c.element_type): can_implement = False # Skip invalid mma tile shape and cluster shape - if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( - use_2cta_instrs, mma_tiler_mn, cluster_shape_mn - ): + if not self.is_valid_mma_tiler_and_cluster_shape(): can_implement = False # Skip illegal problem shape for load/store alignment - if not DenseGemmKernel.is_valid_tensor_alignment( - m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + if not self.is_valid_tensor_alignment( + m, n, k, l, a.element_type, c.element_type, a_major, b_major, c_major ): can_implement = False # Skip invalid epilogue store option - if not DenseGemmKernel.is_valid_epilog_store_option( - use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn - ): + if not self.is_valid_epilog_store_option(m, n): can_implement = False + return can_implement -def run_dense_gemm( +def create_tensors(l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype): + torch.manual_seed(1111) + + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major == "m", c_dtype) + + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + return ( + a_tensor, + b_tensor, + c_tensor, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + c_torch_gpu, + ) + + +def compare(a_torch_cpu, b_torch_cpu, c_torch_gpu, c_dtype, tolerance): + # Copy gpu result back + kernel_result = c_torch_gpu.cpu() + + # Compute reference result + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch_cpu.to(dtype=torch.float32), + b_torch_cpu.to(dtype=torch.float32), + ) + + # Convert ref to c_dtype + _, ref_torch_gpu = cutlass_torch.cute_tensor_like( + ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + ref_result = ref_torch_gpu.cpu() + + # Assert close results + torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05) + + +def run( mnkl: Tuple[int, int, int, int], ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], @@ -1555,19 +1504,56 @@ def run_dense_gemm( a_major: str, b_major: str, c_major: str, - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], - use_2cta_instrs: bool, - use_tma_store: bool, - tolerance: float, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Tuple[int, int] = (2, 1), + use_2cta_instrs: bool = True, + use_tma_store: bool = True, + tolerance: float = 1e-01, warmup_iterations: int = 0, iterations: int = 1, skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, ): + """Execute a batched dense GEMM operation on Blackwell architecture with software pipeline and performance benchmarking. + + This function prepares input tensors, configures and launches the GEMM kernel with software pipeline, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size, defaults to (256, 256) + :type mma_tiler_mn: Tuple[int, int], optional + :param cluster_shape_mn: Cluster shape, defaults to (2, 1) + :type cluster_shape_mn: Tuple[int, int], optional + :param use_2cta_instrs: Whether to use 2CTA instructions, defaults to True + :type use_2cta_instrs: bool, optional + :param use_tma_store: Whether to use TMA store, defaults to True + :type use_tma_store: bool, optional + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float """ - Prepare A/B/C tensors, launch GPU kernel, and reference checking. - """ - print(f"Running B100 software pipeline Dense GEMM test with:") + print("Running Blackwell Software Pipeline Dense GEMM test with:") print(f"mnkl: {mnkl}") print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") @@ -1578,184 +1564,89 @@ def run_dense_gemm( print(f"Warmup iterations: {warmup_iterations}") print(f"Iterations: {iterations}") print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") # Unpack parameters m, n, k, l = mnkl - # Skip unsupported testcase - if not DenseGemmKernel.can_implement( - ab_dtype, - acc_dtype, - c_dtype, - use_2cta_instrs, - mma_tiler_mn, - cluster_shape_mn, - use_tma_store, - m, - n, - k, - l, - a_major, - b_major, - c_major, - ): - raise TypeError( - f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" - ) - if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") - torch.manual_seed(1111) + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) - # Create and permute tensor A/B/C - def create_and_permute_tensor( - l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True - ): - # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) - # else: (l, mode0, mode1) -> (mode0, mode1, l) - shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) - permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) - is_unsigned = dtype in {cutlass.Uint8} - # Temporarily use uint8 as torch does not support fp8 type - torch_dtype = ( - cutlass_torch.dtype(dtype) - if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} - else torch.uint8 - ) - - # Create dtype torch tensor (cpu) - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch_dtype, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 - ), - ) - # Create dtype torch tensor (gpu) - torch_tensor = torch_tensor_cpu.cuda() - - # Create f32 torch tensor (cpu) - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - # Create dtype cute tensor (gpu) - cute_tensor = from_dlpack(torch_tensor, assumed_align=16) - cute_tensor.element_type = dtype - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic( - leading_dim=(0 if is_mode0_major else 1) - ) - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=is_dynamic_layout, - ) - - return f32_torch_tensor, cute_tensor, torch_tensor - - a_ref, a_tensor, a_torch = create_and_permute_tensor( - l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True - ) - b_ref, b_tensor, b_torch = create_and_permute_tensor( - l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True - ) - c_ref, c_tensor, c_torch = create_and_permute_tensor( - l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + a_tensor, b_tensor, c_tensor, a_torch_cpu, b_torch_cpu, c_torch_cpu, c_torch_gpu = ( + create_tensors(l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype) ) - # Configure gemm kernel + # Build GEMM object gemm = DenseGemmKernel( - acc_dtype, - use_2cta_instrs, - mma_tiler_mn, - cluster_shape_mn, - use_tma_store, + acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, use_tma_store ) - torch_stream = torch.cuda.Stream() - stream = cuda.CUstream(torch_stream.cuda_stream) - # Compile gemm kernel - compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) - - # Launch GPU kernel - # Warm up - for i in range(warmup_iterations): - compiled_gemm(a_tensor, b_tensor, c_tensor, stream) - # Execution - for i in range(iterations): - compiled_gemm(a_tensor, b_tensor, c_tensor, stream) - - # Compute reference result - if not skip_ref_check: - if ab_dtype in { - cutlass.Int8, - cutlass.Uint8, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: - ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) - else: - ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() - - # Copy gpu result back - gpu_c = c_torch.cpu() - - # Convert ref to c_type - if c_dtype == cutlass.Float32: - ref_c = ref - elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: - # m major: (l, n, m) -> (m, n, l) - # n major: (l, m, n) -> (m, n, l) - permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) - shape = (l, m, n) if c_major == "n" else (l, n, m) - f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch.uint8, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.SKIP, - ).cuda() - # Create dtype cute tensor (gpu) - ref_c_tensor = from_dlpack( - f8_torch_tensor, assumed_align=16 - ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) - ref_c_tensor.element_type = c_dtype - ref_c_tensor = cutlass_torch.convert_cute_tensor( - ref, - ref_c_tensor, - c_dtype, - is_dynamic_layout=True, - ) - - ref_c = f8_torch_tensor.cpu() - else: - ref_c = ref.to(cutlass_torch.dtype(c_dtype)) - - # Reference checking ref_c and gpu_c - torch.testing.assert_close( - gpu_c, - ref_c, - atol=tolerance, - rtol=1e-05, + # Check if configuration can be implemented + can_implement = gemm.can_implement(a_tensor, b_tensor, c_tensor) + if not can_implement: + raise ValueError( + f"The current config which is invalid/unsupported: use_2cta_instrs = {use_2cta_instrs}, " + f"mma_tiler_mn = {mma_tiler_mn}, cluster_shape_mn = {cluster_shape_mn}, " + f"use_tma_store = {use_tma_store}" ) + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, current_stream) + + if not skip_ref_check: + compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) + compare(a_torch_cpu, b_torch_cpu, c_torch_gpu, c_dtype, tolerance) + + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + if __name__ == "__main__": def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) - # or: return tuple([int(x.strip()) for x in s.split(",")]) except ValueError: raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers." ) - parser = argparse.ArgumentParser( - description="Example of MxNxKxL GEMM on Blackwell." - ) + parser = argparse.ArgumentParser(description="Example of Dense GEMM on Blackwell.") parser.add_argument( "--mnkl", @@ -1767,7 +1658,7 @@ if __name__ == "__main__": "--mma_tiler_mn", type=parse_comma_separated_ints, default=(128, 128), - help="Mma tiler (comma-separated)", + help="Mma tile shape (comma-separated)", ) parser.add_argument( "--cluster_shape_mn", @@ -1795,10 +1686,21 @@ if __name__ == "__main__": parser.add_argument( "--warmup_iterations", type=int, default=0, help="Warmup iterations" ) - parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference checking" ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) args = parser.parse_args() @@ -1811,7 +1713,7 @@ if __name__ == "__main__": if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") - run_dense_gemm( + run( args.mnkl, args.ab_dtype, args.c_dtype, @@ -1827,5 +1729,6 @@ if __name__ == "__main__": args.warmup_iterations, args.iterations, args.skip_ref_check, + args.use_cold_l2, ) print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/fmha.py b/examples/python/CuTeDSL/blackwell/fmha.py index 560129df..58905eef 100644 --- a/examples/python/CuTeDSL/blackwell/fmha.py +++ b/examples/python/CuTeDSL/blackwell/fmha.py @@ -27,10 +27,11 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse -import enum import math +import os +import sys import time -from typing import Type, Tuple +from typing import Type, Tuple, Union, Optional import torch import torch.nn.functional as F @@ -45,7 +46,11 @@ import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.cute.testing as testing from cutlass.cute.runtime import from_dlpack -from cutlass.cute.typing import Int32, Int64, Float32, Boolean +from cutlass.cute.typing import Int32, Int64, Float32 + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(current_dir, "..")) +from utils import fmha_helpers as fmha_utils """ A fused multi-head attention (FMHA) example for the NVIDIA Blackwell SM100 architecture using CUTE DSL @@ -95,167 +100,9 @@ Constraints for this example: * For persistent scheduling, use --is_persistent (note: specify without =True/False) """ -class FmhaStaticTileSchedulerParams: - def __init__( - self, - is_persistent: bool, - problem_shape_mbh: cute.Shape, - *, - loc=None, - ip=None, - ): - self.is_persistent = is_persistent - self.problem_shape_mbh = problem_shape_mbh - self._loc = loc - self._ip = ip - - def __extract_mlir_values__(self): - values, self._values_pos = [], [] - for obj in [self.is_persistent, self.problem_shape_mbh]: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - obj_list = [] - for obj, n_items in zip( - [self.is_persistent, self.problem_shape_mbh], self._values_pos - ): - obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) - - -def create_fmha_static_tile_scheduler_params( - is_persistent: bool, - problem_shape_mbh: cute.Shape, -) -> FmhaStaticTileSchedulerParams: - return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) - -class FmhaStaticTileScheduler: - - def __init__( - self, - params: FmhaStaticTileSchedulerParams, - current_work_linear_idx: Int32, - blk_coord: cute.Coord, - grid_shape: cute.Shape, - *, - loc=None, - ip=None, - ): - self._params = params - self._blk_coord = blk_coord - self._grid_shape = grid_shape - self._is_persistent = params.is_persistent - self._current_work_linear_idx = current_work_linear_idx - self._problem_shape_mbh = cute.make_layout( - params.problem_shape_mbh, loc=loc, ip=ip - ) - self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) - self._is_first_block = True - self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) - self._loc = loc - self._ip = ip - - # called by host - @staticmethod - def get_grid_shape( - params: FmhaStaticTileSchedulerParams, - *, - loc=None, - ip=None, - ) -> cute.Shape: - if params.is_persistent: - hardware_info = cutlass.utils.HardwareInfo() - sm_count = hardware_info.get_device_multiprocessor_count() - return ( - cutlass.min( - sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip) - ), - 1, - 1, - ) - else: - return params.problem_shape_mbh - - @staticmethod - def check_valid_work_for_seqlen_q( - q_tiler: int, - current_idx: Int32, - seqlen_q: Int32, - ) -> Boolean: - return current_idx * q_tiler < seqlen_q - - def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo: - is_valid = ( - self._current_work_linear_idx < self._num_blocks - if self._is_persistent - else self._is_first_block - ) - - blk_coord = (0, 0, 0) - if self._is_persistent: - blk_coord = self._problem_shape_mbh.get_hier_coord( - self._current_work_linear_idx, loc=loc, ip=ip - ) - else: - blk_coord = self._blk_coord - - # cur_tile_coord is (mid, 0, (bid, hid)) - cur_tile_coord = ( - blk_coord[0], - 0, - (blk_coord[1], blk_coord[2]), - ) - - return utils.WorkTileInfo(cur_tile_coord, is_valid) - - def initial_work_tile_info(self, *, loc=None, ip=None): - return self.get_current_work(loc=loc, ip=ip) - - def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): - if self._is_persistent: - self._current_work_linear_idx += advance_count * self.num_persistent_sm - self._is_first_block = False - - def __extract_mlir_values__(self): - values = cutlass.extract_mlir_values(self._params) - values.extend(cutlass.extract_mlir_values(self._current_work_linear_idx)) - values.extend(cutlass.extract_mlir_values(self._blk_coord)) - values.extend(cutlass.extract_mlir_values(self._grid_shape)) - return values - - def __new_from_mlir_values__(self, values): - assert len(values) == 10 - new_params = cutlass.new_from_mlir_values(self._params, values[0:3]) - new_current_work_linear_idx = cutlass.new_from_mlir_values( - self._current_work_linear_idx, [values[3]] - ) - new_blk_coord = cutlass.new_from_mlir_values(self._blk_coord, values[4:7]) - new_grid_shape = cutlass.new_from_mlir_values(self._grid_shape, values[7:]) - return FmhaStaticTileScheduler( - new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape - ) - - -def create_fmha_static_tile_scheduler( - params: FmhaStaticTileSchedulerParams, - blk_coord: cute.Coord, - grid_shape: cute.Shape, -) -> FmhaStaticTileScheduler: - return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) - - -class MaskType(enum.Enum): - NO_MASK = enum.auto() - RESIDUAL_MASK = enum.auto() - CAUSAL_MASK = enum.auto() - def make_thread_cooperative_group(size: int): - return pipeline.CooperativeGroup(pipeline.Agent.Thread, size, size) + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) class BlackwellFusedMultiHeadAttentionForward: @@ -265,7 +112,7 @@ class BlackwellFusedMultiHeadAttentionForward: pv_acc_dtype: Type[cutlass.Numeric], mma_tiler: Tuple[int, int, int], is_persistent: bool, - mask_type: MaskType, + mask_type: fmha_utils.MaskEnum, ): """Initializes the configuration for a Blackwell Fused Multi-Head Attention (FMHA) kernel. @@ -283,6 +130,7 @@ class BlackwellFusedMultiHeadAttentionForward: 3. Kernel Execution Mode: - is_persistent: Boolean indicating whether to use persistent kernel mode - mask_type: Specifies the type of mask to use (no mask, residual mask, or causal mask) + - window_size_left/right: Sliding window size for attention masking :param qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator :type qk_acc_dtype: Type[cutlass.Numeric] @@ -293,7 +141,11 @@ class BlackwellFusedMultiHeadAttentionForward: :param is_persistent: Whether to use persistent kernel mode :type is_persistent: bool :param mask_type: Type of mask to use - :type mask_type: MaskType + :type mask_type: fmha_utils.MaskEnum + :param window_size_left: Left-side sliding window size for attention masking + :type window_size_left: int + :param window_size_right: Right-side sliding window size for attention masking + :type window_size_right: int """ self.qk_acc_dtype = qk_acc_dtype @@ -312,6 +164,7 @@ class BlackwellFusedMultiHeadAttentionForward: self.cluster_shape_mn = (1, 1) self.is_persistent = is_persistent self.mask_type = mask_type + self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) @@ -335,8 +188,14 @@ class BlackwellFusedMultiHeadAttentionForward: ) ) - self.cta_sync_bar_id = 0 - self.tmem_alloc_sync_bar_id = 1 + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=self.threads_per_warp, + ) self.tmem_s0_offset = 0 self.tmem_s1_offset = 128 @@ -352,7 +211,6 @@ class BlackwellFusedMultiHeadAttentionForward: self.num_regs_softmax = 192 self.num_regs_correction = 96 self.num_regs_other = 32 - self.num_regs_empty = 24 self.buffer_align_bytes = 1024 @@ -388,10 +246,14 @@ class BlackwellFusedMultiHeadAttentionForward: v_iter: cute.Pointer, o_iter: cute.Pointer, problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32], - cum_seqlen_q: cute.Tensor | None, - cum_seqlen_k: cute.Tensor | None, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + lse_iter: Optional[cute.Pointer], scale_softmax_log2: Float32, + scale_softmax: Float32, scale_output: Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], stream: cuda.CUstream, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -415,22 +277,28 @@ class BlackwellFusedMultiHeadAttentionForward: :type v_iter: cute.Pointer :param o_iter: The output tensor pointer :type o_iter: cute.Pointer - :param problem_size: The problem size with shape [b, s_q, s_k, h_q, h_k, d]. If cum_seqlen_q or cum_seqlen_k is not None, s_q and s_k are the max of the cumulative sequence length respectively. + :param problem_size: The problem size with shape [b, s_q, s_lse, s_k, h_q, h_k, d]. If cum_seqlen_q or cum_seqlen_k is not None, s_q and s_k are the max of the cumulative sequence length respectively. :type problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32] :param cum_seqlen_q: The cumulative sequence length tensor for query - :type cum_seqlen_q: cute.Tensor | None + :type cum_seqlen_q: Optional[cute.Tensor] :param cum_seqlen_k: The cumulative sequence length tensor for key - :type cum_seqlen_k: cute.Tensor | None + :type cum_seqlen_k: Optional[cute.Tensor] :param scale_softmax_log2: The log2 scale factor for softmax :type scale_softmax_log2: Float32 + :param scale_softmax: The scale factor for softmax + :type scale_softmax: Float32 :param scale_output: The scale factor for the output :type scale_output: Float32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] :param stream: The CUDA stream to execute the kernel on :type stream: cuda.CUstream :raises TypeError: If tensor data types don't match or aren't supported :raises RuntimeError: If tensor layouts aren't in supported formats """ - b, s_q, s_k, h_q, h_k, d = problem_size + b, s_q, s_lse, s_k, h_q, h_k, d = problem_size h_r = h_q // h_k qo_offset = 0 if cum_seqlen_q is None else -s_q * d * h_r * h_k kv_offset = 0 if cum_seqlen_k is None else -s_k * d * h_k @@ -438,6 +306,8 @@ class BlackwellFusedMultiHeadAttentionForward: b_kv = b if cum_seqlen_k is None else s_k * (1 + b) stride_b_qo = h_r * h_k * s_q * d if cum_seqlen_q is None else d * h_r * h_k stride_b_kv = h_k * s_k * d if cum_seqlen_k is None else d * h_k + b_lse = b if cum_seqlen_q is None else 1 + stride_b_lse = h_r * h_k * s_lse if cum_seqlen_q is None else 0 # (s, d, ((h_r, h_k), b)) q_layout = cute.make_layout( @@ -463,6 +333,15 @@ class BlackwellFusedMultiHeadAttentionForward: stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)), ) o = cute.make_tensor(o_iter + qo_offset, o_layout) + if cutlass.const_expr(lse_iter is not None): + # (s, ((h_r, h_k), b)) + lse_layout = cute.make_layout( + (s_lse, ((h_r, h_k), b_lse)), + stride=(1, ((s_lse, h_r * s_lse), stride_b_lse)), + ) + lse = cute.make_tensor(lse_iter, lse_layout) + else: + lse = None # setup static attributes before smem/grid/tma computation self.q_dtype = q.element_type @@ -470,7 +349,7 @@ class BlackwellFusedMultiHeadAttentionForward: self.v_dtype = v.element_type self.o_dtype = o.element_type - self.tile_sched_params, grid = self._compute_grid( + self.tile_sched_params, grid = fmha_utils.compute_grid( cute.shape((s_q, d, ((h_r, h_k), b))), self.cta_tiler, self.is_persistent, @@ -591,16 +470,13 @@ class BlackwellFusedMultiHeadAttentionForward: self.cluster_layout_vmnk.shape, ) - o_cta_v_layout = cute.composition( - cute.make_identity_layout(o.shape), self.epi_tile - ) o_smem_layout = cute.select(o_smem_layout_staged, mode=[0, 1]) tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tiled_tma_atom( tma_store_op, o, o_smem_layout, - o_cta_v_layout, + self.epi_tile, ) q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) @@ -655,8 +531,12 @@ class BlackwellFusedMultiHeadAttentionForward: tma_tensor_o, cum_seqlen_q, cum_seqlen_k, + lse, scale_softmax_log2, + scale_softmax, scale_output, + window_size_left, + window_size_right, q_smem_layout_staged, k_smem_layout_staged, p_tmem_layout_staged, @@ -685,16 +565,20 @@ class BlackwellFusedMultiHeadAttentionForward: mV_dkl: cute.Tensor, tma_atom_o: cute.CopyAtom, mO_qdl: cute.Tensor, - cum_seqlen_q: cute.Tensor | None, - cum_seqlen_k: cute.Tensor | None, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], scale_softmax_log2: Float32, + scale_softmax: Float32, scale_output: Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], q_smem_layout_staged: cute.ComposedLayout, k_smem_layout_staged: cute.ComposedLayout, p_tmem_layout_staged: cute.ComposedLayout, v_smem_layout_staged: cute.ComposedLayout, o_smem_layout_staged: cute.ComposedLayout, - tile_sched_params: FmhaStaticTileSchedulerParams, + tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams, ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -733,6 +617,10 @@ class BlackwellFusedMultiHeadAttentionForward: :type scale_softmax_log2: Float32 :param scale_output: The scale factor for the output :type scale_output: Float32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] :param q_smem_layout_staged: Shared memory layout for query tensor :type q_smem_layout_staged: cute.ComposedLayout :param k_smem_layout_staged: Shared memory layout for key tensor @@ -744,9 +632,8 @@ class BlackwellFusedMultiHeadAttentionForward: :param o_smem_layout_staged: Shared memory layout for output tensor :type o_smem_layout_staged: cute.ComposedLayout :param tile_sched_params: Scheduling parameters for work distribution - :type tile_sched_params: FmhaStaticTileSchedulerParams + :type tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams """ - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # coord inside cta tidx, _, _ = cute.arch.thread_idx() @@ -908,15 +795,12 @@ class BlackwellFusedMultiHeadAttentionForward: + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, tOrP.layout, ) - cute.arch.barrier( - barrier_id=self.cta_sync_bar_id, - number_of_threads=self.threads_per_cta, - ) + self.cta_sync_barrier.arrive_and_wait() # /////////////////////////////////////////////////////////////////////////////// # EMPTY # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.empty_warp_id: - cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) # /////////////////////////////////////////////////////////////////////////////// # LOAD @@ -924,7 +808,7 @@ class BlackwellFusedMultiHeadAttentionForward: if warp_idx == self.load_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - tile_sched = create_fmha_static_tile_scheduler( + tile_sched = fmha_utils.create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() @@ -938,12 +822,10 @@ class BlackwellFusedMultiHeadAttentionForward: if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = ( - not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.cta_tiler[0], - curr_block_coord[0], - seqlen_q, - ) + continue_cond = not fmha_utils.FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, ) if not continue_cond: mQ_qdl_ = mQ_qdl @@ -1038,7 +920,15 @@ class BlackwellFusedMultiHeadAttentionForward: tma_bar_ptr=q0_handle.barrier, ) # K0 - kv_coord = 0 # seqlen_kv_loop + seqlen_kv_loop_start = fmha_utils.FusedMask.get_trip_start( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + ) + kv_coord = seqlen_kv_loop_start k_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_k, @@ -1066,7 +956,15 @@ class BlackwellFusedMultiHeadAttentionForward: kv_coord += 1 seqlen_kv_loop_steps = ( - self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + fmha_utils.FusedMask.get_trip_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) - 1 ) for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): @@ -1102,11 +1000,8 @@ class BlackwellFusedMultiHeadAttentionForward: # Alloc tmem buffer tmem_alloc_cols = Int32(self.tmem_alloc_cols) cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) - cute.arch.barrier( - barrier_id=self.tmem_alloc_sync_bar_id, - number_of_threads=self.threads_per_warp, - ) - tile_sched = create_fmha_static_tile_scheduler( + self.tmem_alloc_barrier.arrive_and_wait() + tile_sched = fmha_utils.create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() @@ -1115,15 +1010,14 @@ class BlackwellFusedMultiHeadAttentionForward: curr_block_coord = work_tile.tile_idx batch_coord = curr_block_coord[2][1] continue_cond = False + seqlen_q = mQ_qdl.shape[0] if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = ( - not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.cta_tiler[0], - curr_block_coord[0], - seqlen_q, - ) + continue_cond = not fmha_utils.FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, ) if not continue_cond: @@ -1144,13 +1038,13 @@ class BlackwellFusedMultiHeadAttentionForward: # 4. gemm num_kphases = cute.size(tSrQ0, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord_0 = (None, None, kphase_idx) + kphase_coord = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( qk_tiled_mma, tStS0, - tSrQ0[kphase_coord_0], - tSrK0[kphase_coord_0], + tSrQ0[kphase_coord], + tSrK0[kphase_coord], tStS0, ) # 5. release S0 @@ -1166,13 +1060,13 @@ class BlackwellFusedMultiHeadAttentionForward: # 3. gemm num_kphases = cute.size(tSrQ1, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord_1 = (None, None, kphase_idx) + kphase_coord = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( qk_tiled_mma, tStS1, - tSrQ1[kphase_coord_1], - tSrK0[kphase_coord_1], + tSrQ1[kphase_coord], + tSrK0[kphase_coord], tStS1, ) # 4. release S1 @@ -1198,13 +1092,13 @@ class BlackwellFusedMultiHeadAttentionForward: # 4. gemm num_kphases = cute.size(tOrP0, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord_2 = (None, None, kphase_idx) + kphase_coord = (None, None, kphase_idx) pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( pv_tiled_mma, tOtO0, - tOrP0[kphase_coord_2], - tOrVi[kphase_coord_2], + tOrP0[kphase_coord], + tOrVi[kphase_coord], tOtO0, ) # 5. release accumulated O0_partial @@ -1212,7 +1106,15 @@ class BlackwellFusedMultiHeadAttentionForward: # End of GEMM_PV00 (P0 * V0 -> O0_partial) seqlen_kv_loop_steps = ( - self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + fmha_utils.FusedMask.get_trip_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) - 1 ) @@ -1228,13 +1130,13 @@ class BlackwellFusedMultiHeadAttentionForward: for kphase_idx in cutlass.range( inner_num_kphases, unroll_full=True ): - kphase_coord_3 = (None, None, kphase_idx) + kphase_coord = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( qk_tiled_mma, tStS0, - tSrQ0[kphase_coord_3], - tSrKi[kphase_coord_3], + tSrQ0[kphase_coord], + tSrKi[kphase_coord], tStS0, ) # 3. release S0 @@ -1251,13 +1153,13 @@ class BlackwellFusedMultiHeadAttentionForward: for kphase_idx in cutlass.range( inner_num_kphases, unroll_full=True ): - kphase_coord_4 = (None, None, kphase_idx) + kphase_coord = (None, None, kphase_idx) pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) cute.gemm( pv_tiled_mma, tOtO1, - tOrP1[kphase_coord_4], - tOrVi[kphase_coord_4], + tOrP1[kphase_coord], + tOrVi[kphase_coord], tOtO1, ) pv_whether_acc = True @@ -1273,13 +1175,13 @@ class BlackwellFusedMultiHeadAttentionForward: for kphase_idx in cutlass.range( inner_num_kphases, unroll_full=True ): - kphase_coord_5 = (None, None, kphase_idx) + kphase_coord = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( qk_tiled_mma, tStS1, - tSrQ1[kphase_coord_5], - tSrKi[kphase_coord_5], + tSrQ1[kphase_coord], + tSrKi[kphase_coord], tStS1, ) s1_handle.commit() @@ -1300,13 +1202,13 @@ class BlackwellFusedMultiHeadAttentionForward: for kphase_idx in cutlass.range( inner_num_kphases, unroll_full=True ): - kphase_coord_6 = (None, None, kphase_idx) + kphase_coord = (None, None, kphase_idx) pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) cute.gemm( pv_tiled_mma, tOtO0, - tOrP0[kphase_coord_6], - tOrVi[kphase_coord_6], + tOrP0[kphase_coord], + tOrVi[kphase_coord], tOtO0, ) # 5. release accumulated O0_partial @@ -1326,15 +1228,16 @@ class BlackwellFusedMultiHeadAttentionForward: # 3. gemm num_kphases = cute.size(tOrP1, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): - kphase_coord_7 = (None, None, kphase_idx) - pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kphase_coord = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) cute.gemm( pv_tiled_mma, tOtO1, - tOrP1[kphase_coord_7], - tOrVi[kphase_coord_7], + tOrP1[kphase_coord], + tOrVi[kphase_coord], tOtO1, ) + pv_whether_acc = True # 4. commit accumulated O1 o1_handle.commit() # 5. release Vi_end @@ -1367,7 +1270,7 @@ class BlackwellFusedMultiHeadAttentionForward: # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.epilogue_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - tile_sched = create_fmha_static_tile_scheduler( + tile_sched = fmha_utils.create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() @@ -1382,12 +1285,10 @@ class BlackwellFusedMultiHeadAttentionForward: if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = ( - not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.cta_tiler[0], - curr_block_coord[0], - seqlen_q, - ) + continue_cond = not fmha_utils.FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, ) if not continue_cond: curr_block_coord_o = curr_block_coord @@ -1456,12 +1357,15 @@ class BlackwellFusedMultiHeadAttentionForward: self.softmax( stage=0, seqlen_k=mK_kdl.shape[0], + seqlen_q=mQ_qdl.shape[0], cum_seqlen_q=cum_seqlen_q, cum_seqlen_k=cum_seqlen_k, scale_softmax_log2=scale_softmax_log2, qk_thr_mma=qk_thr_mma, tStS=tStS, tStSi=tStS0, + window_size_left=window_size_left, + window_size_right=window_size_right, mma_si_consumer=mma_s0_consumer, si_corr_producer=s0_corr_producer, s0_s1_sequence_consumer=s0_s1_sequence_consumer, @@ -1483,12 +1387,15 @@ class BlackwellFusedMultiHeadAttentionForward: self.softmax( stage=1, seqlen_k=mK_kdl.shape[0], + seqlen_q=mQ_qdl.shape[0], cum_seqlen_q=cum_seqlen_q, cum_seqlen_k=cum_seqlen_k, scale_softmax_log2=scale_softmax_log2, qk_thr_mma=qk_thr_mma, tStS=tStS, tStSi=tStS1, + window_size_left=window_size_left, + window_size_right=window_size_right, mma_si_consumer=mma_s1_consumer, si_corr_producer=s1_corr_producer, s0_s1_sequence_consumer=s0_s1_sequence_consumer, @@ -1531,29 +1438,39 @@ class BlackwellFusedMultiHeadAttentionForward: tTMEM_LOAD_VECtS1 = thr_tmem_load_vec.partition_S(tStS_vec1) tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec) - tile_sched = create_fmha_static_tile_scheduler( + tile_sched = fmha_utils.create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx + curr_block_coord_lse = curr_block_coord batch_coord = curr_block_coord[2][1] seqlen_k = mK_kdl.shape[0] continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = ( - not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.cta_tiler[0], - curr_block_coord[0], - seqlen_q, - ) + # for varlen LSE, batch == 1 + curr_block_coord_lse = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], 0), + ) + continue_cond = not fmha_utils.FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, ) if not continue_cond: + row_idx = ( + curr_block_coord[0] * self.cta_tiler[0] + tTMEM_LOAD_VECcS[0][0] + ) if cutlass.const_expr(cum_seqlen_k is not None): cuseqlen_k = cum_seqlen_k[batch_coord] seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k @@ -1561,14 +1478,23 @@ class BlackwellFusedMultiHeadAttentionForward: vec0_handle = s0_corr_consumer.wait_and_advance() vec0_handle.release() vec1_handle = s1_corr_consumer.wait_and_advance() + seqlen_kv_loop_steps = ( - self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + fmha_utils.FusedMask.get_trip_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) - 1 ) for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): # wait for vec0 (row_wise current max & previous max) vec0_handle = s0_corr_consumer.wait_and_advance() - tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECrS = cute.make_rmem_tensor( tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype ) cute.copy( @@ -1605,7 +1531,7 @@ class BlackwellFusedMultiHeadAttentionForward: # wait for vec0 (row_wise global sum) vec0_handle = s0_corr_consumer.wait_and_advance() - tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECrS = cute.make_rmem_tensor( tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype ) cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) @@ -1617,6 +1543,13 @@ class BlackwellFusedMultiHeadAttentionForward: self.correction_epilog( pv_thr_mma, tOtO0, + mLSE, + tTMEM_LOAD_VECrS, + row_idx, + cuseqlen_q, + seqlen_q, + curr_block_coord_lse, + scale_softmax, scale_output / tTMEM_LOAD_VECrS[0], sO[None, None, 0], ) @@ -1631,9 +1564,17 @@ class BlackwellFusedMultiHeadAttentionForward: # wait for o1 o1_handle = mma_corr_consumer.wait_and_advance() o1_final_handle = corr_epi_producer.acquire_and_advance() + row_idx += self.qk_mma_tiler[0] self.correction_epilog( pv_thr_mma, tOtO1, + mLSE, + tTMEM_LOAD_VECrS, + row_idx, + cuseqlen_q, + seqlen_q, + curr_block_coord_lse, + scale_softmax, scale_output / tTMEM_LOAD_VECrS[0], sO[None, None, 1], ) @@ -1686,7 +1627,7 @@ class BlackwellFusedMultiHeadAttentionForward: :type need_apply_mask: bool :param iter_args: Tuple containing the counting tensor, row_max, row_sum, and vector buffer's handle for current iteration :type iter_args: tuple - :param value_args: Tuple containing seqlen_k and scale_softmax_log2 + :param value_args: Tuple containing seqlen_k, seqlen_q, and scale_softmax_log2 :type value_args: tuple :param pipeline_args: Tuple containing pipeline related arguments for MMA, correction, and sequence synchronization :type pipeline_args: tuple @@ -1698,7 +1639,9 @@ class BlackwellFusedMultiHeadAttentionForward: :rtype: tuple """ cS, row_max, row_sum, vec_i_handle = iter_args - seqlen_k, scale_softmax_log2 = value_args + seqlen_k, seqlen_q, scale_softmax_log2, window_size_left, window_size_right = ( + value_args + ) ( mma_si_consumer, si_corr_producer, @@ -1735,17 +1678,25 @@ class BlackwellFusedMultiHeadAttentionForward: # Wait for Si si_handle = mma_si_consumer.wait_and_advance() - tTMEM_LOADrS = cute.make_fragment(tTMEM_LOADcS.shape, self.qk_acc_dtype) + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) if need_apply_mask: - self.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, seqlen_k) + fmha_utils.FusedMask.apply_mask( + self.mask_type, + tTMEM_LOADrS, + tTMEM_LOADcS, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) old_row_max = row_max row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) row_max_safe = row_max if row_max == -cutlass.Float32.inf: row_max_safe = 0.0 - tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECrS = cute.make_rmem_tensor( tTMEM_STORE_VECcS.shape, self.qk_acc_dtype ) tTMEM_STORE_VECrS[0] = old_row_max @@ -1755,7 +1706,7 @@ class BlackwellFusedMultiHeadAttentionForward: # Notify correction wg that row_max is ready vec_i_handle.commit() - tTMEM_STORErS_x4 = cute.make_fragment(tTMEM_STOREcS.shape, self.qk_acc_dtype) + tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype) tTMEM_STORErS_x4_e = cute.make_tensor( cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout, @@ -1850,17 +1801,20 @@ class BlackwellFusedMultiHeadAttentionForward: self, stage: int, seqlen_k: Int32, - cum_seqlen_q: cute.Tensor | None, - cum_seqlen_k: cute.Tensor | None, + seqlen_q: Int32, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], scale_softmax_log2: Float32, - qk_thr_mma: cute.core.ThrMma, + qk_thr_mma: cute.ThrMma, tStS: cute.Tensor, tStSi: cute.Tensor, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], mma_si_consumer: pipeline.PipelineConsumer, si_corr_producer: pipeline.PipelineProducer, s0_s1_sequence_consumer: pipeline.PipelineConsumer, s0_s1_sequence_producer: pipeline.PipelineProducer, - tile_sched_params: FmhaStaticTileSchedulerParams, + tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1875,14 +1829,26 @@ class BlackwellFusedMultiHeadAttentionForward: :param stage: Processing stage (0 for first half, 1 for second half of attention matrix) :type stage: int + :param seqlen_k: Length of the key sequence + :type seqlen_k: Int32 + :param seqlen_q: Length of the query sequence + :type seqlen_q: Int32 + :param cum_seqlen_q: Cumulative sequence lengths for queries + :type cum_seqlen_q: cute.Tensor | None + :param cum_seqlen_k: Cumulative sequence lengths for keys + :type cum_seqlen_k: cute.Tensor | None :param scale_softmax_log2: Log2 scale factor for softmax operation :type scale_softmax_log2: Float32 :param qk_thr_mma: Thread MMA operation for QK matrix multiplication - :type qk_thr_mma: cute.core.ThrMma + :type qk_thr_mma: cute.ThrMma :param tStS: Shared tensor for softmax input/output :type tStS: cute.Tensor :param tStSi: Input tensor containing attention scores :type tStSi: cute.Tensor + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] :param mma_si_pipeline: Pipeline for synchronizing with MMA operations :type mma_si_pipeline: pipeline.PipelineAsync :param si_corr_pipeline: Pipeline for synchronizing with correction operations @@ -1890,7 +1856,7 @@ class BlackwellFusedMultiHeadAttentionForward: :param s0_s1_sequence_pipeline: Pipeline for synchronizing between stage 0 and 1 :type s0_s1_sequence_pipeline: pipeline.PipelineAsync :param tile_sched_params: Parameters for tile scheduling - :type tile_sched_params: FmhaStaticTileSchedulerParams + :type tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams """ tidx, _, _ = cute.arch.thread_idx() thread_idx = tidx % ( @@ -1948,7 +1914,7 @@ class BlackwellFusedMultiHeadAttentionForward: thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) - tile_sched = create_fmha_static_tile_scheduler( + tile_sched = fmha_utils.create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() @@ -1957,17 +1923,17 @@ class BlackwellFusedMultiHeadAttentionForward: curr_block_coord = work_tile.tile_idx batch_coord = curr_block_coord[2][1] seqlen_k_ = seqlen_k + seqlen_q_ = seqlen_q continue_cond = False - + cuseqlen_q = Int32(0) + seqlen_q_ = seqlen_q if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] - seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = ( - not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.cta_tiler[0], - curr_block_coord[0], - seqlen_q, - ) + seqlen_q_ = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not fmha_utils.FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q_, ) if not continue_cond: @@ -1976,7 +1942,13 @@ class BlackwellFusedMultiHeadAttentionForward: seqlen_k_ = cum_seqlen_k[batch_coord + 1] - cuseqlen_k row_max = -Float32.inf row_sum = 0.0 - value_args = (seqlen_k_, scale_softmax_log2) + value_args = ( + seqlen_k_, + seqlen_q_, + scale_softmax_log2, + window_size_left, + window_size_right, + ) atom_args = ( qk_thr_mma, tiled_tmem_load, @@ -1999,12 +1971,68 @@ class BlackwellFusedMultiHeadAttentionForward: ) cS = cute.domain_offset(logical_offset, cS_base) vec_i_handle = si_corr_producer.acquire_and_advance() - unmask_count = self.get_unmasked_trip_count( + + start_count = fmha_utils.FusedMask.get_trip_start( + self.mask_type, curr_block_coord, self.cta_tiler, + seqlen_q_, seqlen_k_, + window_size_left, ) - for i in cutlass.range(0, unmask_count, 1, unroll=1): + + leading_mask_count = fmha_utils.FusedMask.get_masked_leading_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q_, + seqlen_k_, + window_size_left, + window_size_right, + ) + for i in cutlass.range( + start_count, start_count + leading_mask_count, 1, unroll=1 + ): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = (cS_iter, row_max, row_sum, vec_i_handle) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = self.softmax_step( + stage, + True, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, + ) + unmask_count = fmha_utils.FusedMask.get_unmasked_trip_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q_, + seqlen_k_, + window_size_left, + window_size_right, + ) + for i in cutlass.range( + start_count + leading_mask_count, + start_count + leading_mask_count + unmask_count, + 1, + unroll=1, + ): cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) iter_args = (cS_iter, row_max, row_sum, vec_i_handle) pipeline_args = ( @@ -2030,14 +2058,24 @@ class BlackwellFusedMultiHeadAttentionForward: atom_args, tensor_args, ) - mask_count = self.get_masked_trip_count( + trailing_mask_count = fmha_utils.FusedMask.get_masked_trailing_count( + self.mask_type, curr_block_coord, self.cta_tiler, + seqlen_q_, seqlen_k_, + window_size_left, + window_size_right, ) for i in cutlass.range( - unmask_count, unmask_count + mask_count, 1, unroll=1 + start_count + leading_mask_count + unmask_count, + start_count + + leading_mask_count + + unmask_count + + trailing_mask_count, + 1, + unroll=1, ): cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) iter_args = (cS_iter, row_max, row_sum, vec_i_handle) @@ -2065,7 +2103,7 @@ class BlackwellFusedMultiHeadAttentionForward: tensor_args, ) si_handle = mma_si_consumer.wait_and_advance() - tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECrS = cute.make_rmem_tensor( tTMEM_STORE_VECcS.shape, self.qk_acc_dtype ) tTMEM_STORE_VECrS[0] = row_sum @@ -2085,7 +2123,7 @@ class BlackwellFusedMultiHeadAttentionForward: @cute.jit def correction_rescale( self, - thr_mma: cute.core.ThrMma, + thr_mma: cute.ThrMma, tOtO: cute.Tensor, scale: Float32, ): @@ -2102,7 +2140,7 @@ class BlackwellFusedMultiHeadAttentionForward: 3. Store the rescaled results back to tensor memory :param thr_mma: Thread MMA operation for the computation - :type thr_mma: cute.core.ThrMma + :type thr_mma: cute.ThrMma :param tOtO: Tensor representing partial attention output to be rescaled :type tOtO: cute.Tensor :param scale: Scaling factor to apply to the partial results @@ -2147,7 +2185,7 @@ class BlackwellFusedMultiHeadAttentionForward: tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i) - tTMrO = cute.make_fragment( + tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.pv_acc_dtype ) for i in range(self.cta_tiler[2] // corr_tile_size): @@ -2174,8 +2212,15 @@ class BlackwellFusedMultiHeadAttentionForward: @cute.jit def correction_epilog( self, - thr_mma: cute.core.ThrMma, + thr_mma: cute.ThrMma, tOtO: cute.Tensor, + mLSE: Optional[cute.Tensor], + tTMEM_LOAD_VECrS: cute.Tensor, + row_idx: Int32, + cuseqlen_q: Int32, + seqlen_q: Int32, + blk_coord: Int32, + scale_softmax: Float32, scale: Float32, sO: cute.Tensor, ): @@ -2193,9 +2238,23 @@ class BlackwellFusedMultiHeadAttentionForward: 5. Preparation for efficient TMA store operations :param thr_mma: Thread MMA operation for the computation - :type thr_mma: cute.core.ThrMma + :type thr_mma: cute.ThrMma :param tOtO: Tensor containing accumulated attention output :type tOtO: cute.Tensor + :param mLSE: Tensor containing log-sum-exp values for LSE calculation + :type mLSE: cute.Tensor | None + :param tTMEM_LOAD_VECrS: Tensor containing row sum and max values for softmax calculation + :type tTMEM_LOAD_VECrS: cute.Tensor + :param row_idx: Index of the current row being processed + :type row_idx: Int32 + :param cuseqlen_q: Cumulative sequence length of the current query + :type cuseqlen_q: Int32 + :param seqlen_q: Sequence length of the current query + :type seqlen_q: Int32 + :param blk_coord: Coordinate of the current block being processed + :type blk_coord: Int32 + :param scale_softmax: Scaling factor for softmax calculation + :type scale_softmax: Float32 :param scale: Final scaling factor to apply to the output :type scale: Float32 :param sO: Shared memory tensor for the final output @@ -2245,7 +2304,7 @@ class BlackwellFusedMultiHeadAttentionForward: for i in range(self.cta_tiler[2] // corr_tile_size): tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i] tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i] - tTMrO = cute.make_fragment( + tTMrO = cute.make_rmem_tensor( tTMEM_LOADoO[None, 0, 0, i].shape, self.pv_acc_dtype ) cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO) @@ -2254,120 +2313,27 @@ class BlackwellFusedMultiHeadAttentionForward: (tTMrO[j], tTMrO[j + 1]), (scale, scale), ) - tSMrO = cute.make_fragment(tTMrO.shape, self.o_dtype) + tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype) o_vec = tTMrO.load() tSMrO.store(o_vec.to(self.o_dtype)) cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO_i) + if cutlass.const_expr(mLSE is not None): + scaled_tmp = scale_softmax * tTMEM_LOAD_VECrS[1] + lse = cute.math.log(tTMEM_LOAD_VECrS[0], fastmath=True) + scaled_tmp + if row_idx < seqlen_q: + mLSE[row_idx + cuseqlen_q, blk_coord[2]] = lse + # fence view async shared cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, ) - def get_trip_count( - self, - blk_coord: cute.Coord, - tile_shape: cute.Shape, - seqlen_k: Int32, - ) -> Int32: - result = 0 - if ( - self.mask_type == MaskType.NO_MASK - or self.mask_type == MaskType.RESIDUAL_MASK - ): - result = cute.ceil_div(seqlen_k, tile_shape[1]) - elif self.mask_type == MaskType.CAUSAL_MASK: - max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) - max_blocks_q = cute.ceil_div( - (blk_coord[0] + 1) * tile_shape[0], tile_shape[1] - ) - result = cutlass.min(max_blocks_k, max_blocks_q) - return result - - @cute.jit - def get_masked_trip_count( - self, - blk_coord: cute.Coord, - tile_shape: cute.Shape, - seqlen_k: Int32, - ) -> Int32: - result = 0 - if self.mask_type == MaskType.NO_MASK: - result = 0 - elif self.mask_type == MaskType.RESIDUAL_MASK: - if seqlen_k % tile_shape[1] != 0: - result = 1 - else: - result = 0 - elif self.mask_type == MaskType.CAUSAL_MASK: - trip_count = self.get_trip_count(blk_coord, tile_shape, seqlen_k) - result = cutlass.min( - trip_count, - cute.ceil_div(tile_shape[0], tile_shape[1]), - ) - return result - - @cute.jit - def get_unmasked_trip_count( - self, - blk_coord: cute.Coord, - tile_shape: cute.Shape, - seqlen_k: Int32, - ) -> Int32: - result = 0 - if self.mask_type == MaskType.NO_MASK: - result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) - elif self.mask_type == MaskType.RESIDUAL_MASK: - if seqlen_k % tile_shape[1] != 0: - result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) - 1 - else: - result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) - elif self.mask_type == MaskType.CAUSAL_MASK: - result = self.get_trip_count( - blk_coord, tile_shape, seqlen_k - ) - self.get_masked_trip_count(blk_coord, tile_shape, seqlen_k) - return result - - @cute.jit - def apply_mask( - self, - acc_qk: cute.Tensor, - index_qk: cute.Tensor, - seqlen_k: Int32, - ): - if self.mask_type == MaskType.RESIDUAL_MASK: - for i in range(cute.size(acc_qk)): - pos = index_qk[i] - if pos[1] >= seqlen_k: - acc_qk[i] = -Float32.inf - elif self.mask_type == MaskType.CAUSAL_MASK: - for i in range(cute.size(acc_qk)): - pos = index_qk[i] - if pos[0] < pos[1] or pos[1] >= seqlen_k: - acc_qk[i] = -Float32.inf - - @staticmethod - def _compute_grid( - o_shape: cute.Shape, - cta_tiler: Tuple[int, int, int], - is_persistent: bool, - ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: - tile_sched_params = create_fmha_static_tile_scheduler_params( - is_persistent, - ( - cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), - cute.size(o_shape[2][0]), - cute.size(o_shape[2][1]), - ), - ) - grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) - return tile_sched_params, grid - def run( - q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], - k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], + q_shape: Union[Tuple[int, int, int, int], Tuple[int, Tuple[int, ...], int, int]], + k_shape: Union[Tuple[int, int, int, int], Tuple[int, Tuple[int, ...], int, int]], in_dtype: Type[cutlass.Numeric], out_dtype: Type[cutlass.Numeric], qk_acc_dtype: Type[cutlass.Numeric], @@ -2375,6 +2341,9 @@ def run( mma_tiler_mn: Tuple[int, int], is_persistent: bool, is_causal: bool, + bottom_right_align: bool, + lse_calculation: bool, + window_size: Tuple[int, int], scale_q: float, scale_k: float, scale_v: float, @@ -2401,11 +2370,11 @@ def run( :param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length, H=number of heads, D=head dimension. If S_q is a tuple, it is the variable sequence length. - :type q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :type q_shape: Union[Tuple[int, int, int, int], Tuple[int, Tuple[int, ...], int, int]] :param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length, H_k=number of key heads (H must be divisible by H_k), D=head dimension. If S_k is a tuple, it is the variable sequence length. - :type k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :type k_shape: Union[Tuple[int, int, int, int], Tuple[int, Tuple[int, ...], int, int]] :param in_dtype: Input data type for query, key and value tensors :type in_dtype: Type[cutlass.Numeric] :param out_dtype: Output data type for attention output @@ -2420,6 +2389,10 @@ def run( :type is_persistent: bool :param is_causal: Whether to apply causal masking :type is_causal: bool + :param lse_calculation: Whether to calculate lse + :type lse_calculation: bool + :param window_size: Sliding window size (left, right) for attention masking. Controls which positions each query can attend to. + :type window_size: Tuple[int, int] :param scale_q: Scaling factor for query tensor :type scale_q: float :param scale_k: Scaling factor for key tensor @@ -2447,7 +2420,7 @@ def run( :rtype: float """ - print(f"Running Blackwell SM100 FMHA test with:") + print("Running Blackwell SM100 FMHA test with:") print(f" q_shape: {q_shape}") print(f" k_shape: {k_shape}") print(f" in_dtype: {in_dtype}") @@ -2457,6 +2430,9 @@ def run( print(f" mma_tiler_mn: {mma_tiler_mn}") print(f" is_persistent: {is_persistent}") print(f" is_causal: {is_causal}") + print(f" bottom_right_align: {bottom_right_align}") + print(f" lse_calculation: {lse_calculation}") + print(f" window_size: {window_size}") print(f" scale_q: {scale_q}") print(f" scale_k: {scale_k}") print(f" scale_v: {scale_v}") @@ -2471,6 +2447,11 @@ def run( # Unpack parameters b, s_q, h_q, d = q_shape b_, s_k, h_k, d_ = k_shape + window_size_left, window_size_right = window_size + if window_size_left == -1: + window_size_left = None + if window_size_right == -1: + window_size_right = None if b != b_: raise ValueError("q & k must have the same batch size") @@ -2580,11 +2561,15 @@ def run( # From ragged to jagged if s_cumsum is not None: + if len(shape) == 4: + jagged_dim = 1 # for q,k,v,o + else: + jagged_dim = 2 # for lse torch_tensor = torch.nested.nested_tensor_from_jagged( - values=torch_tensor, offsets=s_cumsum + values=torch_tensor, offsets=s_cumsum, jagged_dim=jagged_dim ) f32_torch_tensor = torch.nested.nested_tensor_from_jagged( - values=f32_torch_tensor, offsets=s_cumsum.cpu() + values=f32_torch_tensor, offsets=s_cumsum.cpu(), jagged_dim=jagged_dim ) return ( @@ -2595,12 +2580,15 @@ def run( qo_shape = (b, s_q, h_r * h_k, d) kv_shape = (b, s_k, h_k, d) + lse_shape = (b, h_r * h_k, s_q) qo_padding = (0, 0, 0, 0, 0) kv_padding = (0, 0, 0, 0, 0) + lse_padding = (0, 0, 0, 0) if isinstance(s_q, tuple): qo_shape = (1, sum(s_q), h_r * h_k, d) qo_padding = (0, max(s_q), 0, 0, 0) + lse_shape = (1, h_r * h_k, sum(s_q)) if isinstance(s_k, tuple): kv_shape = (1, sum(s_k), h_k, d) @@ -2634,20 +2622,66 @@ def run( s_cumsum=cum_seqlen_q_torch, is_dynamic_layout=True, ) + if lse_calculation: + _, lse_tensor, lse_torch = create_and_pad_tensor( + lse_shape, + lse_padding, + cutlass.Float32, + is_dynamic_layout=True, + ) + else: + lse_tensor = None + lse_torch = None mma_tiler = (*mma_tiler_mn, d) - mask_type = MaskType.NO_MASK + mask_type = fmha_utils.MaskEnum.WINDOW_MASK + if bottom_right_align: + mask_type = fmha_utils.MaskEnum.WINDOW_MASK_INFERENCE if is_causal: - mask_type = MaskType.CAUSAL_MASK - else: + window_size_right = 0 + elif window_size_left is None and window_size_right is None: if isinstance(s_k, tuple): for i in range(len(s_k)): if s_k[i] % mma_tiler_mn[1] != 0: - mask_type = MaskType.RESIDUAL_MASK + mask_type = fmha_utils.MaskEnum.RESIDUAL_MASK else: if s_k % mma_tiler_mn[1] != 0: - mask_type = MaskType.RESIDUAL_MASK + mask_type = fmha_utils.MaskEnum.RESIDUAL_MASK + + s_q_list = s_q if isinstance(s_q, tuple) else [s_q] * b + s_k_list = s_k if isinstance(s_k, tuple) else [s_k] * b + + # To avoid mask out the whole row which results in NaN in softmax + def check_seqlen_valid( + s_q, s_k, window_size_left, window_size_right, bottom_right_align + ): + for i in range(s_q): + offset = 0 if not bottom_right_align else s_k - s_q + + s_q_start = 0 if window_size_left is None else i + offset - window_size_left + s_q_end = ( + s_q if window_size_right is None else i + offset + window_size_right + ) + s_q_min = max(s_q_start, 0) + s_q_max = min(s_q_end, s_k) + + if s_q_max - s_q_min == 0 and (i != 0 and i != s_q - 1): + return False + return True + + need_check_seqlen_valid = ( + window_size_left is not None or window_size_right is not None + ) + for i in range(b): + if need_check_seqlen_valid and not check_seqlen_valid( + s_q_list[i], + s_k_list[i], + window_size_left, + window_size_right, + bottom_right_align, + ): + raise ValueError("sliding window doesn't support current setting") fmha = BlackwellFusedMultiHeadAttentionForward( qk_acc_dtype, @@ -2673,6 +2707,7 @@ def run( problem_size = ( b, max(s_q) if isinstance(s_q, tuple) else s_q, + sum(s_q) if isinstance(s_q, tuple) else s_q, # s_lse max(s_k) if isinstance(s_k, tuple) else s_k, h_q, h_k, @@ -2691,65 +2726,107 @@ def run( problem_size, cum_seqlen_q, cum_seqlen_k, + lse_tensor.iterator if lse_calculation else None, scale_softmax_log2, + scale_softmax, scale_output, + window_size_left if window_size_left is None else Int32(window_size_left), + window_size_right if window_size_right is None else Int32(window_size_right), current_stream, ) compilation_time = time.time() - start_time print(f"Compilation time: {compilation_time:.4f} seconds") - def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False): + def run_torch_fmha( + q, + k, + v, + scale_softmax=1.0, + scale_output=1.0, + is_causal=False, + bottom_right_align=False, + lse_calculation=False, + window_size_left=None, + window_size_right=None, + ): h_q = q.shape[2] h_k = k.shape[2] + if not h_q == h_k: + repeat_factor = h_q // h_k + # nested tensor can not be broadcasted directly + if k.is_nested: + k_offsets = k.offsets() + v_offsets = v.offsets() + k_values = k.values().repeat_interleave(repeat_factor, dim=1) + v_values = v.values().repeat_interleave(repeat_factor, dim=1) + + k = torch.nested.nested_tensor_from_jagged( + values=k_values, offsets=k_offsets + ) + v = torch.nested.nested_tensor_from_jagged( + values=v_values, offsets=v_offsets + ) + else: + k = k.repeat_interleave(repeat_factor, dim=2) + v = v.repeat_interleave(repeat_factor, dim=2) + # as we initialize q, k, v with shape (b, s, h, d) and SDPA of torch needs them to be (b, h, s, d) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) - # For the situation that torch has not supported, we need to handle it manually - situation1 = (h_q != h_k or is_causal) and (q.is_nested or k.is_nested) - situation2 = (q.is_nested and not k.is_nested) or ( - not q.is_nested and k.is_nested - ) - if situation1 or situation2: - # Once torch supports the situation, we can remove this fallback - batch_size = q.size(0) - ref_list = [] - for batch_idx in range(batch_size): - q_i = q[batch_idx] - k_i = k[batch_idx] - v_i = v[batch_idx] + batch_size = q.size(0) + ref_list = [] + lse_list = [] + for batch_idx in range(batch_size): + q_i = q[batch_idx] + k_i = k[batch_idx] + v_i = v[batch_idx] + s_i = torch.einsum("hqd,hkd->hqk", q_i, k_i) * scale_softmax + s_q = q_i.shape[1] + s_k = k_i.shape[1] + if is_causal: + window_size_right = 0 + if window_size_left is not None or window_size_right is not None: + q_coords = torch.arange(0, s_q).cuda().view(-1, 1) + k_coords = torch.arange(0, s_k).cuda().view(1, -1) + offset = 0 if not bottom_right_align else s_k - s_q + if window_size_left is None: + _mask = k_coords > q_coords + offset + window_size_right + elif window_size_right is None: + _mask = k_coords < q_coords + offset - window_size_left + else: + _mask = (k_coords > q_coords + offset + window_size_right) | ( + k_coords < q_coords + offset - window_size_left + ) + s_i = s_i.masked_fill(_mask.cpu(), -torch.inf) - ref_i = F.scaled_dot_product_attention( - q_i, - k_i, - v_i, - attn_mask=None, - dropout_p=0.0, - scale=scale_softmax, - is_causal=is_causal, - enable_gqa=(h_q != h_k), - ) - ref_i = ref_i.transpose(0, 1) * scale_output - ref_list.append(ref_i) - if q.is_nested: - ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged) + if lse_calculation: + lse_i = torch.logsumexp(s_i, dim=-1) else: - ref = torch.stack(ref_list) + lse_i = None + + p_i = torch.softmax(s_i, dim=-1) + ref_i = torch.einsum("hqk,hkd->hqd", p_i, v_i) + ref_i = ref_i.transpose(0, 1) * scale_output + ref_list.append(ref_i) + if lse_calculation: + lse_list.append(lse_i) + if q.is_nested: + ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged) + if lse_calculation: + lse = torch.cat(lse_list, dim=1).unsqueeze(0) + else: + lse = None else: - ref = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - dropout_p=0.0, - scale=scale_softmax, - is_causal=is_causal, - enable_gqa=(h_q != h_k), - ) - ref = ref.transpose(1, 2) * scale_output - return ref + ref = torch.stack(ref_list) + if lse_calculation: + lse = torch.stack(lse_list) + else: + lse = None + + return ref, lse if not skip_ref_check: # Execute kernel once for reference checking @@ -2761,13 +2838,30 @@ def run( problem_size, cum_seqlen_q, cum_seqlen_k, + lse_tensor.iterator if lse_calculation else None, scale_softmax_log2, + scale_softmax, scale_output, + window_size_left if window_size_left is None else Int32(window_size_left), + ( + window_size_right + if window_size_right is None + else Int32(window_size_right) + ), current_stream, ) print("Verifying results...") - o_ref = run_torch_fmha( - q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal + o_ref, lse_ref = run_torch_fmha( + q_ref, + k_ref, + v_ref, + scale_softmax, + scale_output, + is_causal, + bottom_right_align, + lse_calculation, + window_size_left, + window_size_right, ) if o_ref.is_nested: @@ -2812,6 +2906,10 @@ def run( # Assert close results torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05) + if lse_calculation: + torch.testing.assert_close( + lse_torch.cpu(), lse_ref, atol=tolerance, rtol=1e-05 + ) print("Results verified successfully!") def generate_tensors(): @@ -2843,6 +2941,16 @@ def run( s_cumsum=cum_seqlen_q_torch, is_dynamic_layout=True, ) + if lse_calculation: + _, lse_tensor, lse_torch = create_and_pad_tensor( + lse_shape, + lse_padding, + cutlass.Float32, + is_dynamic_layout=True, + ) + else: + lse_tensor = None + return testing.JitArguments( q_tensor_workspace.iterator, k_tensor_workspace.iterator, @@ -2851,8 +2959,16 @@ def run( problem_size, cum_seqlen_q, cum_seqlen_k, + lse_tensor, scale_softmax_log2, + scale_softmax, scale_output, + window_size_left if window_size_left is None else Int32(window_size_left), + ( + window_size_right + if window_size_right is None + else Int32(window_size_right) + ), current_stream, ) @@ -2867,6 +2983,11 @@ def run( + k_torch_effective.numel() * k_torch_effective.element_size() + v_torch_effective.numel() * v_torch_effective.element_size() + o_torch_effective.numel() * o_torch_effective.element_size() + + ( + lse_torch.numel() * lse_torch.element_size() + if lse_torch is not None + else 0 + ) ) workspace_count = testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations @@ -2979,6 +3100,25 @@ if __name__ == "__main__": help="Whether to use casual mask", ) + parser.add_argument( + "--bottom_right_align", + action="store_true", + help="Whether to use bottom right align, under this settion, the end of q is aligned with the end of k.", + ) + + parser.add_argument( + "--lse_calculation", + action="store_true", + help="Whether to calculate lse", + ) + + parser.add_argument( + "--window_size", + type=parse_comma_separated_ints, + default=(-1, -1), + help="Sliding window size (left, right) for attention masking.", + ) + parser.add_argument( "--q_shape", type=parse_nested_comma_separated_ints, @@ -3085,6 +3225,9 @@ if __name__ == "__main__": args.mma_tiler_mn, args.is_persistent, args.is_causal, + args.bottom_right_align, + args.lse_calculation, + args.window_size, args.scale_q, args.scale_k, args.scale_v, diff --git a/examples/python/CuTeDSL/blackwell/fmha_bwd.py b/examples/python/CuTeDSL/blackwell/fmha_bwd.py new file mode 100644 index 00000000..88820c1b --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/fmha_bwd.py @@ -0,0 +1,3569 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import enum +import math +import os +import sys +import random +import time +from typing import Type, Tuple, Union, Optional + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Int32, Float32, Float8E4M3FN, Float16, BFloat16, Boolean + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(current_dir, "..")) +from utils import fmha_helpers as fmha_utils + +""" +A fused multi-head attention (FMHA) backward pass example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of the backward pass of fused multi-head attention +using a TMA + Blackwell SM100 TensorCore warp-specialized kernel. The implementation fuses the computation of +dQ, dK, and dV into a single kernel, avoiding intermediate data movement between +global memory and shared memory, thus improving computational efficiency. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, compute, reduce) +- Pipeline stages between different warps for overlapping computation and memory access +- Support causal masking +- Support for variable sequence lengths +- Support for sliding window attention + +To run this example: + +.. code-block:: bash + + python examples/blackwell/fmha_bwd.py \\ + --s_q_max 1024 --s_k_max 1024 \\ + --h_q 8 --h_k 8 --d 128 --b 1 \\ + --element_dtype float16 --acc_dtype float32 \\ + --mma_tiler_mn 128,128 + +The above example runs FMHA backward with max sequence length 1024 for Q and K, +batch size 1, 8 attention heads for Q and K, and head dimension 128. +The Blackwell tcgen05 MMA tile shape is (128, 128), and the kernel uses fp16 for input/output +with fp32 for accumulation. + +Constraints for this example: +* Supported head dimensions: 64, and 128 +* mma_tiler_mn must be 128,128 +* For causal masking, use --is_causal +* For variable sequence lengths, use --varlen +* For sliding window attention, use --window_size x,y +""" + + +class BlackwellFusedMultiHeadAttentionBackward: + def __init__( + self, + element_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + mma_tiler: Tuple[int, int, int], + varlen: bool, + mask_type: fmha_utils.MaskEnum, + ): + self.element_dtype = element_dtype + self.acc_dtype = acc_dtype + self.cta_tiler = ( + mma_tiler[0], + mma_tiler[1], + mma_tiler[2], + ) + self.tile_shape_Q = mma_tiler[0] + self.tile_shape_K = mma_tiler[1] + self.tile_shape_dQ_K = mma_tiler[2] + self.tile_shape_dV_dO = mma_tiler[2] + # For S + self.KQ_mma_tiler = ( + mma_tiler[1], + mma_tiler[0], + mma_tiler[2], + ) + # For dP + self.VdO_mma_tiler = ( + mma_tiler[1], + mma_tiler[0], + mma_tiler[2], + ) + # For dV + self.PdO_mma_tiler = ( + mma_tiler[1], + mma_tiler[2], + mma_tiler[0], + ) + # For dK + self.dSQ_mma_tiler = ( + mma_tiler[1], + mma_tiler[2], + mma_tiler[0], + ) + # For dQ + self.dSK_mma_tiler = ( + mma_tiler[0], + mma_tiler[2], + mma_tiler[1], + ) + self.cluster_shape_mn = (1, 1) + self.varlen = varlen + self.mask_type = mask_type + + # =================== Sum OdO ================================ + self.sum_OdO_max_threads_per_block = 128 + self.sum_OdO_block_q = 16 + self.sum_OdO_num_threads_d = 8 + self.sum_OdO_num_threads_q = ( + self.sum_OdO_max_threads_per_block // self.sum_OdO_num_threads_d + ) + self.sum_OdO_elem_per_load = 2 + + self.reduce_warp_id = (0, 1, 2, 3) + self.compute_warp_id = (4, 5, 6, 7, 8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.empty_warp_id = 14 + + self.num_reduce_warps = 4 + self.num_compute_warps = 8 + + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * ( + self.num_reduce_warps + self.num_compute_warps + 4 + ) + + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=self.threads_per_warp, + ) + self.compute_sync_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=self.num_compute_warps * self.threads_per_warp, + ) + self.epilogue_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.num_compute_warps * self.threads_per_warp, + ) + self.reduce_sync_barrier = pipeline.NamedBarrier( + barrier_id=5, + num_threads=self.num_reduce_warps * self.threads_per_warp, + ) + + self.tmem_dK_offset = 0 + self.tmem_dV_offset = self.tmem_dK_offset + mma_tiler[2] + self.tmem_dQ_offset = self.tmem_dV_offset + mma_tiler[2] + self.tmem_dP_offset = self.tmem_dQ_offset + self.tmem_S_offset = self.tmem_dQ_offset + max(mma_tiler[0], mma_tiler[2]) + + self.num_regs_reduce = 152 + self.num_regs_compute = 128 + self.num_regs_mma = 96 + self.num_regs_empty = 96 + self.num_regs_load = 96 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.load_mma_Q_stage = 2 + self.load_mma_dO_stage = 1 + self.load_compute_LSE_stage = 1 + self.load_compute_sum_OdO_stage = 1 + self.mma_compute_S_stage = 1 + self.mma_compute_dP_stage = 1 + self.mma_reduce_dQ_stage = 1 + self.compute_mma_P_stage = 1 + self.compute_mma_dS_stage = 1 + self.mma_compute_dKdV_stage = 2 + self.reduce_tma_store_stage = 2 + + @cute.jit + def __call__( + self, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + O: cute.Tensor, + dQ: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + dO: cute.Tensor, + LSE: cute.Tensor, + cumulative_s_q: Union[cute.Tensor, None], + cumulative_s_k: Union[cute.Tensor, None], + scale_softmax: Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + workspace: cute.Tensor, + stream: cuda.CUstream, + ): + q_seq_max, k_seq_max, d, hb = problem_shape + h, b = hb + h_r, h_k = h + # (s, d, h_r, h_k, b) -> (s, d, ((h_r, h_k), b)) + Q = cute.make_tensor( + Q.iterator, + cute.make_layout( + (Q.shape[0], Q.shape[1], hb), + stride=( + Q.stride[0], + Q.stride[1], + ( + (Q.shape[1], Q.shape[1] * Q.shape[2]), + ( + 0 + if self.varlen + else cute.assume( + Q.shape[0] * Q.shape[1] * h_r * h_k, divby=64 + ) + ), + ), + ), + ), + ) + # (s, d, 1, h_k, b) -> (s, d, ((1, h_k), b)) + K = cute.make_tensor( + K.iterator, + cute.make_layout( + (K.shape[0], K.shape[1], hb), + stride=( + K.stride[0], + K.stride[1], + ( + (0, K.shape[1]), + ( + 0 + if self.varlen + else cute.assume( + K.shape[0] * K.shape[1] * 1 * h_k, divby=64 + ) + ), + ), + ), + ), + ) + # (s, d, 1, h_k, b) -> (s, d, ((1, h_k), b)) + V = cute.make_tensor( + V.iterator, + cute.make_layout( + (V.shape[0], V.shape[1], hb), + stride=( + V.stride[0], + V.stride[1], + ( + (0, V.shape[1]), + ( + 0 + if self.varlen + else cute.assume( + V.shape[0] * V.shape[1] * 1 * h_k, divby=64 + ) + ), + ), + ), + ), + ) + O = cute.make_tensor(O.iterator, Q.layout) + + dQ = cute.make_tensor(dQ.iterator, Q.layout) + dK = cute.make_tensor(dK.iterator, K.layout) + dV = cute.make_tensor(dV.iterator, V.layout) + dO = cute.make_tensor(dO.iterator, O.layout) + + # (s, h_r, h_k, b) -> (s, ((h_r, h_k), b)) + LSE = cute.make_tensor( + LSE.iterator, + cute.make_layout( + (LSE.shape[0], hb), + stride=( + LSE.stride[0], + ( + (LSE.shape[0], LSE.shape[0] * LSE.shape[1]), + ( + 0 + if LSE.shape[3] == 1 + else LSE.shape[0] * LSE.shape[1] * LSE.shape[2] + ), + ), + ), + ), + ) + + self.Q_major_mode = utils.LayoutEnum.from_tensor(Q).mma_major_mode() + self.dQ_major_mode = utils.LayoutEnum.from_tensor(dQ).mma_major_mode() + self.K_major_mode = utils.LayoutEnum.from_tensor(K).mma_major_mode() + self.dK_major_mode = utils.LayoutEnum.from_tensor(dK).mma_major_mode() + self.V_major_mode = utils.LayoutEnum.from_tensor(V).mma_major_mode() + self.dV_major_mode = utils.LayoutEnum.from_tensor(dV).mma_major_mode() + self.O_major_mode = utils.LayoutEnum.from_tensor(O).mma_major_mode() + self.dO_major_mode = utils.LayoutEnum.from_tensor(dO).mma_major_mode() + self.dQ_layout = utils.LayoutEnum.from_tensor(dQ) + + if cutlass.const_expr(self.Q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.dQ_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of dq is not supported") + if cutlass.const_expr(self.K_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of dk is not supported") + if cutlass.const_expr(self.V_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of v is not supported") + if cutlass.const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of dv is not supported") + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.ONE + + # compute S + KQ_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.element_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.KQ_mma_tiler[:2], + ) + # compute dP + VdO_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.element_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.VdO_mma_tiler[:2], + ) + # compute dV + PdO_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.element_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.PdO_mma_tiler[:2], + tcgen05.OperandSource.TMEM, + ) + # compute dK + dSQ_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.element_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.dSQ_mma_tiler[:2], + ) + # compute dQ + dSK_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.element_dtype, + tcgen05.OperandMajorMode.MN, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.dSK_mma_tiler[:2], + ) + + self.cluster_shape_mn = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mn), + (KQ_tiled_mma.thr_id.shape,), + ) + + K_smem_layout_staged = sm100_utils.make_smem_layout_a( + KQ_tiled_mma, + self.KQ_mma_tiler, + self.element_dtype, + 1, + ) + Q_smem_layout_staged = sm100_utils.make_smem_layout_b( + KQ_tiled_mma, + self.KQ_mma_tiler, + self.element_dtype, + self.load_mma_Q_stage, + ) + V_smem_layout_staged = sm100_utils.make_smem_layout_a( + VdO_tiled_mma, + self.VdO_mma_tiler, + self.element_dtype, + 1, + ) + dO_smem_layout_staged = sm100_utils.make_smem_layout_b( + VdO_tiled_mma, + self.VdO_mma_tiler, + self.element_dtype, + self.load_mma_dO_stage, + ) + dS_smem_layout_staged = sm100_utils.make_smem_layout_a( + dSK_tiled_mma, + self.dSK_mma_tiler, + self.element_dtype, + self.compute_mma_dS_stage, + ) + KT_smem_layout_staged = sm100_utils.make_smem_layout_b( + dSK_tiled_mma, + self.dSK_mma_tiler, + self.element_dtype, + 1, + ) + dST_smem_layout_staged = sm100_utils.make_smem_layout_a( + dSQ_tiled_mma, + self.dSQ_mma_tiler, + self.element_dtype, + self.compute_mma_dS_stage, + ) + QT_smem_layout_staged = sm100_utils.make_smem_layout_b( + dSQ_tiled_mma, + self.dSQ_mma_tiler, + self.element_dtype, + self.load_mma_Q_stage, + ) + P_tmem_layout_staged = sm100_utils.make_smem_layout_a( + PdO_tiled_mma, + self.PdO_mma_tiler, + self.element_dtype, + self.compute_mma_P_stage, + ) + dOT_smem_layout_staged = sm100_utils.make_smem_layout_b( + PdO_tiled_mma, + self.PdO_mma_tiler, + self.element_dtype, + self.load_mma_dO_stage, + ) + LSE_smem_layout = cute.make_layout( + (self.cta_tiler[0], self.load_compute_LSE_stage) + ) + sum_OdO_smem_layout = cute.make_layout( + (self.cta_tiler[0], self.load_compute_sum_OdO_stage) + ) + + dQ_smem_layout_atom = sm100_utils.make_smem_layout_atom( + sm100_utils.get_smem_layout_atom_ab( + tcgen05.OperandMajorMode.K, + self.acc_dtype, + (self.tile_shape_Q, 32), + ), + self.acc_dtype, + ) + dQ_smem_layout_staged = cute.tile_to_shape( + dQ_smem_layout_atom, + (self.tile_shape_Q, 32, self.reduce_tma_store_stage), + order=(1, 0, 2), + ) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_reduce_op = cpasync.CopyReduceBulkTensorTileS2GOp() + + K_smem_layout = cute.select(K_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + K, + K_smem_layout, + self.KQ_mma_tiler, + KQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + V_smem_layout = cute.select(V_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + V, + V_smem_layout, + self.VdO_mma_tiler, + VdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + Q_smem_layout = cute.select(Q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + Q, + Q_smem_layout, + self.KQ_mma_tiler, + KQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + dO_smem_layout = cute.select(dO_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + dO, + dO_smem_layout, + self.VdO_mma_tiler, + VdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + self.tma_copy_Q_bytes = cute.size_in_bytes(self.element_dtype, Q_smem_layout) + self.tma_copy_K_bytes = cute.size_in_bytes(self.element_dtype, K_smem_layout) + self.tma_copy_V_bytes = cute.size_in_bytes(self.element_dtype, V_smem_layout) + self.tma_copy_dO_bytes = cute.size_in_bytes(self.element_dtype, dO_smem_layout) + + @cute.struct + class SharedStorage: + # Pipeline barriers + load_mma_Q_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_mma_Q_stage * 2 + ] + load_mma_dO_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_mma_dO_stage * 2 + ] + load_compute_lse_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_LSE_stage * 2 + ] + load_compute_sum_OdO_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_sum_OdO_stage * 2 + ] + mma_compute_S_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_S_stage * 2 + ] + mma_compute_dP_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_dP_stage * 2 + ] + mma_reduce_dQ_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_reduce_dQ_stage * 2 + ] + compute_mma_P_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.compute_mma_P_stage * 2 + ] + compute_mma_dS_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.compute_mma_dS_stage * 2 + ] + mma_compute_dKdV_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_dKdV_stage * 2 + ] + tmem_holding_buf: cutlass.Int32 + # Smem tensors + sK: cute.struct.Align[ + cute.struct.MemRange[ + self.element_dtype, cute.cosize(K_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[ + self.element_dtype, cute.cosize(V_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[ + self.element_dtype, cute.cosize(Q_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sdO: cute.struct.Align[ + cute.struct.MemRange[ + self.element_dtype, cute.cosize(dO_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sdS: cute.struct.Align[ + cute.struct.MemRange[ + self.element_dtype, cute.cosize(dS_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sdQ: cute.struct.Align[ + cute.struct.MemRange[ + self.acc_dtype, cute.cosize(dQ_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.acc_dtype, cute.cosize(LSE_smem_layout)], + self.buffer_align_bytes, + ] + sSum_OdO: cute.struct.Align[ + cute.struct.MemRange[self.acc_dtype, cute.cosize(sum_OdO_smem_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + sum_OdO, scaled_LSE, dQ_acc = self.get_workspace_tensor( + problem_shape, workspace, self.acc_dtype + ) + + dQ_smem_layout = cute.select(dQ_smem_layout_staged, mode=[0, 1]) + + tma_atom_dQ_acc, tma_tensor_dQ_acc = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_reduce_op, + dQ_acc, + dQ_smem_layout, + (self.tile_shape_Q, 32), + ) + + # =============================== Sum OdO =============================== + sum_OdO_scale = Float32(-1.0) + LSE_scale = Float32(-math.log2(math.e)) + + sum_OdO_grid = self._compute_sum_OdO_grid(problem_shape, self.sum_OdO_block_q) + + self.sum_OdO( + O, + dO, + sum_OdO, + LSE, + scaled_LSE, + cumulative_s_q, + sum_OdO_scale, + LSE_scale, + problem_shape, + ).launch( + grid=sum_OdO_grid, + block=[self.sum_OdO_num_threads_d, self.sum_OdO_num_threads_q, 1], + cluster=[1, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + # =============================== Bwd =============================== + bwd_grid = self._compute_bwd_grid(problem_shape, self.cta_tiler[1]) + + self.bwd( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + dSK_tiled_mma, + tma_atom_K, + tma_tensor_K, + tma_atom_V, + tma_tensor_V, + tma_atom_Q, + tma_tensor_Q, + tma_atom_dO, + tma_tensor_dO, + tma_atom_dQ_acc, + tma_tensor_dQ_acc, + dK, + dV, + scaled_LSE, + scale_softmax, + sum_OdO, + problem_shape, + cumulative_s_q, + cumulative_s_k, + window_size_left, + window_size_right, + K_smem_layout_staged, + Q_smem_layout_staged, + V_smem_layout_staged, + dO_smem_layout_staged, + dS_smem_layout_staged, + KT_smem_layout_staged, + dST_smem_layout_staged, + QT_smem_layout_staged, + dOT_smem_layout_staged, + dQ_smem_layout_staged, + P_tmem_layout_staged, + LSE_smem_layout, + sum_OdO_smem_layout, + ).launch( + grid=bwd_grid, + block=[self.threads_per_cta, 1, 1], + cluster=[1, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + # =============================== Convert =============================== + self.block_seq = 8 + self.num_threads_D_convert = 16 + self.num_threads_seq = 128 // self.num_threads_D_convert + self.iter_seq = self.block_seq // self.num_threads_seq + self.convert_elem_per_load = 4 + + max_seq_in_qk = max(problem_shape[0], problem_shape[1]) + convert_grid_z = (max_seq_in_qk + self.block_seq - 1) // self.block_seq + convert_grid = [ + cute.size(problem_shape[3][0]), + cute.size(problem_shape[3][1]), + convert_grid_z, + ] + convert_block = [self.num_threads_D_convert, self.num_threads_seq, 1] + + self.convert( + dQ_acc, + dQ, + problem_shape[0], + problem_shape[2], + cumulative_s_q, + scale_softmax, + ).launch( + grid=convert_grid, + block=convert_block, + cluster=[1, 1, 1], + smem=0, + stream=stream, + ) + + @cute.kernel + def sum_OdO( + self, + O: cute.Tensor, + dO: cute.Tensor, + sum_OdO: cute.Tensor, + lse: cute.Tensor, + scaled_lse: cute.Tensor, + cumulative_s_q: Union[cute.Tensor, None], + sum_OdO_scale: Float32, + lse_scale: Float32, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + ): + bidx, bidy, bidz = cute.arch.block_idx() + tidx, tidy, tidz = cute.arch.thread_idx() + + seqlen_q = problem_shape[0] + offset = 0 + if cutlass.const_expr(self.varlen): + offset = cumulative_s_q[bidz] + seqlen_q = cumulative_s_q[bidz + 1] - offset + + for idx_q_t in cutlass.range( + tidy, self.sum_OdO_block_q, self.sum_OdO_num_threads_q, unroll_full=True + ): + idx_q = idx_q_t + self.sum_OdO_block_q * bidx + if idx_q < seqlen_q: + O_bhq = O[idx_q + offset, None, (bidy, bidz)] + O_bhq = cute.logical_divide( + O_bhq, cute.make_layout(self.sum_OdO_elem_per_load) + ) + dO_bhq = dO[idx_q + offset, None, (bidy, bidz)] + dO_bhq = cute.logical_divide( + dO_bhq, cute.make_layout(self.sum_OdO_elem_per_load) + ) + + idx_d_start = tidx + idx_d_step = self.sum_OdO_num_threads_d + acc = 0.0 + for idx_d in cutlass.range( + idx_d_start, O.shape[1] // self.sum_OdO_elem_per_load, idx_d_step + ): + O_frag = O_bhq[None, idx_d].load() + dO_frag = dO_bhq[None, idx_d].load() + prod_frag = O_frag * dO_frag + prod_frag = prod_frag.to(self.acc_dtype) + acc += prod_frag.reduce( + cute.ReductionOp.ADD, 0.0, reduction_profile=0 + ) + + acc = cute.arch.warp_reduction_sum( + acc, threads_in_group=self.sum_OdO_num_threads_d + ) + + if tidx == 0: + lse_bhq = lse[idx_q + offset, (bidy, bidz)] + sum_OdO[idx_q, (bidy, bidz)] = sum_OdO_scale * acc + scaled_lse[idx_q, (bidy, bidz)] = lse_scale * lse_bhq + + @cute.kernel + def bwd( + self, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + dSK_tiled_mma: cute.TiledMma, + tma_atom_K: cute.CopyAtom, + K_in: cute.Tensor, + tma_atom_V: cute.CopyAtom, + V_in: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + Q_in: cute.Tensor, + tma_atom_dO: cute.CopyAtom, + dO_in: cute.Tensor, + tma_atom_dQ_acc: cute.CopyAtom, + dQ_acc: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + LSE: cute.Tensor, + scale_softmax: Float32, + sum_OdO: cute.Tensor, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Int32, Int32]], + cumulative_s_q: Union[cute.Tensor, None], + cumulative_s_k: Union[cute.Tensor, None], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + K_smem_layout_staged: cute.ComposedLayout, + Q_smem_layout_staged: cute.ComposedLayout, + V_smem_layout_staged: cute.ComposedLayout, + dO_smem_layout_staged: cute.ComposedLayout, + dS_smem_layout_staged: cute.ComposedLayout, + KT_smem_layout_staged: cute.ComposedLayout, + dST_smem_layout_staged: cute.ComposedLayout, + QT_smem_layout_staged: cute.ComposedLayout, + dOT_smem_layout_staged: cute.ComposedLayout, + dQ_smem_layout_staged: cute.ComposedLayout, + P_tmem_layout_staged: cute.ComposedLayout, + LSE_smem_layout: cute.Layout, + sum_OdO_smem_layout: cute.Layout, + ): + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + grid_dim_x, grid_dim_y, grid_dim_z = cute.arch.grid_dim() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_dO) + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_mma_Q_pipeline = self.make_and_init_load_mma_Q_pipeline( + storage.load_mma_Q_mbar_ptr.data_ptr() + ) + load_mma_dO_pipeline = self.make_and_init_load_mma_dO_pipeline( + storage.load_mma_dO_mbar_ptr.data_ptr() + ) + load_compute_LSE_pipeline = self.make_and_init_load_compute_LSE_pipeline( + storage.load_compute_lse_mbar_ptr.data_ptr() + ) + load_compute_sum_OdO_pipeline = ( + self.make_and_init_load_compute_sum_OdO_pipeline( + storage.load_compute_sum_OdO_mbar_ptr.data_ptr() + ) + ) + mma_compute_S_pipeline = self.make_and_init_mma_compute_S_pipeline( + storage.mma_compute_S_mbar_ptr.data_ptr() + ) + mma_compute_dP_pipeline = self.make_and_init_mma_compute_dP_pipeline( + storage.mma_compute_dP_mbar_ptr.data_ptr() + ) + mma_reduce_dQ_pipeline = self.make_and_init_mma_reduce_dQ_pipeline( + storage.mma_reduce_dQ_mbar_ptr.data_ptr() + ) + compute_mma_P_pipeline = self.make_and_init_compute_mma_P_pipeline( + storage.compute_mma_P_mbar_ptr.data_ptr() + ) + compute_mma_dS_pipeline = self.make_and_init_compute_mma_dS_pipeline( + storage.compute_mma_dS_mbar_ptr.data_ptr() + ) + mma_compute_dKdV_pipeline = self.make_and_init_mma_compute_dKdV_pipeline( + storage.mma_compute_dKdV_mbar_ptr.data_ptr() + ) + reduce_tma_store_pipeline = self.make_and_init_reduce_tma_store_pipeline() + + self.cta_sync_barrier.arrive_and_wait() + + # setup mma + sQ = storage.sQ.get_tensor( + Q_smem_layout_staged.outer, swizzle=Q_smem_layout_staged.inner + ) + sK = storage.sK.get_tensor( + K_smem_layout_staged.outer, swizzle=K_smem_layout_staged.inner + ) + sV = storage.sV.get_tensor( + V_smem_layout_staged.outer, swizzle=V_smem_layout_staged.inner + ) + sdO = storage.sdO.get_tensor( + dO_smem_layout_staged.outer, swizzle=dO_smem_layout_staged.inner + ) + sdQ = storage.sdQ.get_tensor( + dQ_smem_layout_staged.outer, swizzle=dQ_smem_layout_staged.inner + ) + sLSE = storage.sLSE.get_tensor(LSE_smem_layout) + sSum_OdO = storage.sSum_OdO.get_tensor(sum_OdO_smem_layout) + tmem_holding_buf = storage.tmem_holding_buf + + sQT_ptr = cute.recast_ptr(sQ.iterator, QT_smem_layout_staged.inner) + sQT = cute.make_tensor(sQT_ptr, QT_smem_layout_staged.outer) + sKT_ptr = cute.recast_ptr(sK.iterator, KT_smem_layout_staged.inner) + sKT = cute.make_tensor(sKT_ptr, KT_smem_layout_staged.outer) + sdS = storage.sdS.get_tensor( + dS_smem_layout_staged.outer, swizzle=dS_smem_layout_staged.inner + ) + sdST_ptr = cute.recast_ptr(sdS.iterator, dST_smem_layout_staged.inner) + sdST = cute.make_tensor(sdST_ptr, dST_smem_layout_staged.outer) + tP_fake_ptr = cute.make_ptr(self.element_dtype, 0, cute.AddressSpace.tmem) + tP = cute.make_tensor(tP_fake_ptr, P_tmem_layout_staged.outer) + sdOT_ptr = cute.recast_ptr(sdO.iterator, dOT_smem_layout_staged.inner) + sdOT = cute.make_tensor(sdOT_ptr, dOT_smem_layout_staged.outer) + + # (MMA, MMA_M, MMA_K, STAGE) + tSTrK = KQ_tiled_mma.make_fragment_A(sK) + # (MMA, MMA_N, MMA_K, STAGE) + tSTrQ = KQ_tiled_mma.make_fragment_B(sQ) + + # (MMA, MMA_M, MMA_K, STAGE) + tdPTrV = VdO_tiled_mma.make_fragment_A(sV) + # (MMA, MMA_N, MMA_K, STAGE) + tdPTrdO = VdO_tiled_mma.make_fragment_B(sdO) + + # (MMA, MMA_M, MMA_K, STAGE) + tdQrdS = dSK_tiled_mma.make_fragment_A(sdS) + # (MMA, MMA_N, MMA_K, STAGE) + tdQrKT = dSK_tiled_mma.make_fragment_B(sKT) + + # (MMA, MMA_M, MMA_K, STAGE) + tdKrdST = dSQ_tiled_mma.make_fragment_A(sdST) + # (MMA, MMA_N, MMA_K, STAGE) + tdKrQT = dSQ_tiled_mma.make_fragment_B(sQT) + + tSTtST_shape = KQ_tiled_mma.partition_shape_C( + cute.select(self.KQ_mma_tiler, mode=[0, 1]) + ) + tSTtST = KQ_tiled_mma.make_fragment_C(tSTtST_shape) + # (MMA, MMA_M, MMA_N) + tSTtST = cute.make_tensor(tSTtST.iterator + self.tmem_S_offset, tSTtST.layout) + + # (MMA, MMA_M, MMA_K, STAGE) + tdVrP = PdO_tiled_mma.make_fragment_A(tP) + tdVrP = tdVrP[None, None, None, 0] + tdVrP_iter = cute.recast_ptr(tSTtST.iterator, dtype=self.element_dtype) + tdVrP = cute.make_tensor(tdVrP_iter, tdVrP.layout) + # (MMA, MMA_N, MMA_K, STAGE) + tdVrdOT = PdO_tiled_mma.make_fragment_B(sdOT) + + tdPTtdPT_shape = VdO_tiled_mma.partition_shape_C( + cute.select(self.VdO_mma_tiler, mode=[0, 1]) + ) + tdPTtdPT = VdO_tiled_mma.make_fragment_C(tdPTtdPT_shape) + # (MMA, MMA_M, MMA_N) + tdPTtdPT = cute.make_tensor( + tdPTtdPT.iterator + self.tmem_dP_offset, tdPTtdPT.layout + ) + + tdQtdQ_shape = dSK_tiled_mma.partition_shape_C( + cute.select(self.dSK_mma_tiler, mode=[0, 1]) + ) + tdQtdQ = dSK_tiled_mma.make_fragment_C(tdQtdQ_shape) + # (MMA, MMA_M, MMA_N) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQ_offset, tdQtdQ.layout) + + tdKtdK_shape = dSQ_tiled_mma.partition_shape_C( + cute.select(self.dSQ_mma_tiler, mode=[0, 1]) + ) + tdKtdK = dSQ_tiled_mma.make_fragment_C(tdKtdK_shape) + # (MMA, MMA_M, MMA_N) + tdKtdK = cute.make_tensor(tdKtdK.iterator + self.tmem_dK_offset, tdKtdK.layout) + + tdVtdV_shape = PdO_tiled_mma.partition_shape_C( + cute.select(self.PdO_mma_tiler, mode=[0, 1]) + ) + tdVtdV = PdO_tiled_mma.make_fragment_C(tdVtdV_shape) + # (MMA, MMA_M, MMA_N) + tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset, tdVtdV.layout) + + # get the current batch problem shape + + blk_coord = (Int32(0), bidx, Int32(0), ((Int32(0), bidy), bidz)) + problem_shape_cur_batch = problem_shape + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(self.varlen): + Q_len_cur_batch = cumulative_s_q[bidz + 1] - cumulative_s_q[bidz] + K_len_cur_batch = cumulative_s_k[bidz + 1] - cumulative_s_k[bidz] + problem_shape_cur_batch = ( + Q_len_cur_batch, + K_len_cur_batch, + problem_shape[2], + problem_shape[3], + ) + blk_offset = ( + cumulative_s_q[bidz], + cumulative_s_k[bidz], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + + trip_start = fmha_utils.FusedMask.get_trip_start( + self.mask_type, + blk_coord, + self.cta_tiler, + problem_shape_cur_batch[0], + problem_shape_cur_batch[1], + window_size_left, + window_size_right, + ) + trip_count = fmha_utils.FusedMask.get_trip_count( + self.mask_type, + blk_coord, + self.cta_tiler, + problem_shape_cur_batch[0], + problem_shape_cur_batch[1], + window_size_left, + window_size_right, + ) + trip_end = trip_start + trip_count + + trip_count = trip_count * problem_shape_cur_batch[3][0][0] + + if bidx * self.tile_shape_K < problem_shape_cur_batch[1] and trip_count > 0: + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + + self.load( + K_in, + V_in, + Q_in, + dO_in, + LSE, + sum_OdO, + sK, + sQ, + sV, + sdO, + sLSE, + sSum_OdO, + KQ_tiled_mma, + VdO_tiled_mma, + tma_atom_K, + tma_atom_Q, + tma_atom_V, + tma_atom_dO, + blk_offset, + problem_shape_cur_batch, + trip_count, + trip_start, + trip_end, + ( + load_mma_Q_pipeline, + load_compute_LSE_pipeline, + load_mma_dO_pipeline, + load_compute_sum_OdO_pipeline, + ), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + elif warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_mma) + + self.mma( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSK_tiled_mma, + dSQ_tiled_mma, + tSTtST, + tSTrQ, + tSTrK, + tdPTtdPT, + tdPTrV, + tdPTrdO, + tdVtdV, + tdVrP, + tdVrdOT, + tdQtdQ, + tdQrdS, + tdQrKT, + tdKrdST, + tdKtdK, + tdKrQT, + tmem_holding_buf, + trip_count, + ( + load_mma_Q_pipeline, + mma_compute_S_pipeline, + load_mma_dO_pipeline, + mma_compute_dP_pipeline, + mma_reduce_dQ_pipeline, + compute_mma_P_pipeline, + compute_mma_dS_pipeline, + mma_compute_dKdV_pipeline, + ), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Compute + # /////////////////////////////////////////////////////////////////////////////// + elif ( + warp_idx >= self.compute_warp_id[0] + and warp_idx <= self.compute_warp_id[-1] + ): + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) + + self.compute( + tSTtST, + tdPTtdPT, + tdVrP, + sLSE, + sdS, + sSum_OdO, + dK, + dV, + tdKtdK, + tdVtdV, + PdO_tiled_mma, + dSQ_tiled_mma, + blk_coord, + blk_offset, + problem_shape_cur_batch, + trip_count, + trip_start, + trip_end, + scale_softmax, + window_size_left, + window_size_right, + ( + mma_compute_S_pipeline, + compute_mma_P_pipeline, + load_compute_LSE_pipeline, + load_compute_sum_OdO_pipeline, + mma_compute_dP_pipeline, + compute_mma_dS_pipeline, + mma_compute_dKdV_pipeline, + ), + ) + + self.epilogue_sync_barrier.arrive_and_wait() + + if warp_idx % self.num_compute_warps == 0: + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, self.tmem_alloc_cols) + # /////////////////////////////////////////////////////////////////////////////// + # Reduce + # /////////////////////////////////////////////////////////////////////////////// + elif ( + warp_idx >= self.reduce_warp_id[0] + and warp_idx <= self.reduce_warp_id[-1] + ): + cute.arch.warpgroup_reg_alloc(self.num_regs_reduce) + + self.reduce( + dSK_tiled_mma, + tdQtdQ, + tma_atom_dQ_acc, + dQ_acc, + sdQ, + blk_coord, + problem_shape_cur_batch, + trip_count, + trip_start, + trip_end, + (mma_reduce_dQ_pipeline, reduce_tma_store_pipeline), + ) + + else: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + @cute.kernel + def convert( + self, + dQ_acc: cute.Tensor, + dQ: cute.Tensor, + count: Int32, + d_dim: Int32, + cumulative_s_q: Union[cute.Tensor, None], + scale_softmax: Float32, + ): + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + + seqlen = count + + offset = 0 + if cutlass.const_expr(self.varlen): + offset = cumulative_s_q[bidy] + seqlen = cumulative_s_q[bidy + 1] - offset + + for idx_s_t in cutlass.range(tidy, self.block_seq, self.num_threads_seq): + idx_s = idx_s_t + self.block_seq * bidz + if idx_s < seqlen: + dQ_acc_bhs = dQ_acc[idx_s, None, (bidx, bidy)] + dQ_acc_bhs = cute.logical_divide( + dQ_acc_bhs, cute.make_layout(self.convert_elem_per_load) + ) + dQ_bhs = dQ[idx_s + offset, None, (bidx, bidy)] + dQ_bhs = cute.logical_divide( + dQ_bhs, cute.make_layout(self.convert_elem_per_load) + ) + + thr_start = tidx + thr_step = self.num_threads_D_convert + for idx_d in cutlass.range( + thr_start, + d_dim // self.convert_elem_per_load, + thr_step, + ): + dQ_acc_frg = dQ_acc_bhs[None, idx_d].load() + dQ_acc_frg = scale_softmax * dQ_acc_frg + dQ_bhs[None, idx_d].store(dQ_acc_frg.to(self.element_dtype)) + + @cute.jit + def load( + self, + K_in: cute.Tensor, + V_in: cute.Tensor, + Q_in: cute.Tensor, + dO_in: cute.Tensor, + LSE: cute.Tensor, + sum_OdO: cute.Tensor, + sK: cute.Tensor, + sQ: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sLSE: cute.Tensor, + sSum_OdO: cute.Tensor, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + tma_atom_K: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + blk_offset: cute.Shape, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + iter_count: Int32, + iter_start: Int32, + iter_end: Int32, + # (load_mma_Q_pipeline, load_compute_LSE_pipeline, load_mma_dO_pipeline, load_compute_sum_OdO_pipeline) + pipeline_args: tuple, + ): + tidx, tidy, tidz = cute.arch.thread_idx() + blk_coord_k, blk_coord_h_k, blk_coord_b = cute.arch.block_idx() + blk_coord_h_r = Int32(0) + blk_coord_h = (blk_coord_h_r, blk_coord_h_k) + seq_Q, seq_K, D, HB = problem_shape + H, B = HB + iter_index = iter_start + ( + load_mma_Q_pipeline, + load_compute_LSE_pipeline, + load_mma_dO_pipeline, + load_compute_sum_OdO_pipeline, + ) = pipeline_args + + K = cute.domain_offset(cute.select(blk_offset, mode=[1, 2, 3]), K_in) + V = cute.domain_offset(cute.select(blk_offset, mode=[1, 2, 3]), V_in) + Q = cute.domain_offset(cute.select(blk_offset, mode=[0, 2, 3]), Q_in) + dO = cute.domain_offset(cute.select(blk_offset, mode=[0, 2, 3]), dO_in) + + # (bM, bK, RestM, RestK, (H, B)) + gK = cute.local_tile( + K, cute.select(self.KQ_mma_tiler, mode=[0, 2]), (None, None, None) + ) + # (bN, bK, RestN, RestK, (H, B)) + gQ = cute.local_tile( + Q, cute.select(self.KQ_mma_tiler, mode=[1, 2]), (None, None, None) + ) + # (bM, bK, RestM, RestK, (H, B)) + gV = cute.local_tile( + V, cute.select(self.VdO_mma_tiler, mode=[0, 2]), (None, None, None) + ) + # (bN, bK, RestN, RestK, (H, B)) + gdO = cute.local_tile( + dO, cute.select(self.VdO_mma_tiler, mode=[1, 2]), (None, None, None) + ) + + KQ_thr_mma = KQ_tiled_mma.get_slice(0) + VdO_thr_mma = VdO_tiled_mma.get_slice(0) + + # (MMA, MMA_M, MMA_K, RestM, RestK, (H, B)) + tSTgK = KQ_thr_mma.partition_A(gK) + # (MMA, MMA_N, MMA_K, RestN, RestK, (H, B)) + tSTgQ = KQ_thr_mma.partition_B(gQ) + # (MMA, MMA_M, MMA_K, RestM, RestK, (H, B)) + tdPTgV = VdO_thr_mma.partition_A(gV) + # (MMA, MMA_N, MMA_K, RestN, RestK, (H, B)) + tdPTgdO = VdO_thr_mma.partition_B(gdO) + + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, (H, B)) + tKsK, tKgK_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSTgK, 0, 3), + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, (H, B)) + tQsQ, tQgQ_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_Q, + 0, + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSTgQ, 0, 3), + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, (H, B)) + tVsV, tVgV_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tdPTgV, 0, 3), + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, (H, B)) + tdOsdO, tdOgdO_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_dO, + 0, + cute.make_layout(1), + cute.group_modes(sdO, 0, 3), + cute.group_modes(tdPTgdO, 0, 3), + ) + + load_mma_Q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_mma_Q_stage + ) + load_compute_LSE_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_compute_LSE_stage + ) + load_mma_dO_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_mma_dO_stage + ) + load_compute_sum_OdO_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_compute_sum_OdO_stage + ) + load_mma_Q_pipeline.producer_acquire(load_mma_Q_producer_state) + tma_barrier = load_mma_Q_pipeline.producer_get_barrier( + load_mma_Q_producer_state + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_expect_tx(tma_barrier, self.tma_copy_K_bytes) + + # Load K + cute.copy( + tma_atom_K, + tKgK_mkl[(None, blk_coord_k, 0, (blk_coord_h, blk_coord_b))], + tKsK[None, 0], + tma_bar_ptr=tma_barrier, + ) + + # Load Q + cute.copy( + tma_atom_Q, + tQgQ_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tQsQ[None, load_mma_Q_producer_state.index], + tma_bar_ptr=tma_barrier, + ) + + load_mma_Q_producer_state.advance() + + load_compute_LSE_pipeline.producer_acquire(load_compute_LSE_producer_state) + + # Load LSE + # 32 threads loading 128 values of 32b each + # so 4*32b = 128b + thread_idx = tidx % self.threads_per_warp + + async_copy_num_elts = sLSE.shape[0] // self.threads_per_warp + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + if cute.elem_less(LSE_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx + i, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_LSE_producer_state.index, + ], + ) + else: + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_LSE_producer_state.index, + ].fill(0.0) + + load_compute_LSE_pipeline.producer_commit(load_compute_LSE_producer_state) + load_compute_LSE_producer_state.advance() + + load_mma_dO_pipeline.producer_acquire(load_mma_dO_producer_state) + tma_barrier = load_mma_dO_pipeline.producer_get_barrier( + load_mma_dO_producer_state + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_expect_tx(tma_barrier, self.tma_copy_V_bytes) + + # Load V + cute.copy( + tma_atom_V, + tVgV_mkl[(None, blk_coord_k, 0, (blk_coord_h, blk_coord_b))], + tVsV[(None, 0)], + tma_bar_ptr=tma_barrier, + ) + + # Load dO + cute.copy( + tma_atom_dO, + tdOgdO_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tdOsdO[(None, load_mma_dO_producer_state.index)], + tma_bar_ptr=tma_barrier, + ) + + load_mma_dO_producer_state.advance() + + load_compute_sum_OdO_pipeline.producer_acquire( + load_compute_sum_OdO_producer_state + ) + + # load sum_OdO + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = ( + self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + ) + if cute.elem_less(sum_OdO_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[None, sum_OdO_idx + i, (blk_coord_h, blk_coord_b)], + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_sum_OdO_producer_state.index, + ], + ) + else: + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_sum_OdO_producer_state.index, + ].fill(0.0) + + load_compute_sum_OdO_pipeline.producer_commit( + load_compute_sum_OdO_producer_state + ) + load_compute_sum_OdO_producer_state.advance() + + iter_count -= 1 + iter_index += 1 + + while iter_count > 0: + if iter_index == iter_end: + iter_index = iter_start + blk_coord_h_r += 1 + blk_coord_h = (blk_coord_h_r, blk_coord_h_k) + + load_mma_Q_pipeline.producer_acquire(load_mma_Q_producer_state) + tma_barrier = load_mma_Q_pipeline.producer_get_barrier( + load_mma_Q_producer_state + ) + + # Load Q + cute.copy( + tma_atom_Q, + tQgQ_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tQsQ[None, load_mma_Q_producer_state.index], + tma_bar_ptr=tma_barrier, + ) + + load_mma_Q_producer_state.advance() + + load_compute_LSE_pipeline.producer_acquire(load_compute_LSE_producer_state) + + # Load LSE + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = ( + self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + ) + if cute.elem_less(LSE_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx + i, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_LSE_producer_state.index, + ], + ) + else: + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_LSE_producer_state.index, + ].fill(0.0) + + load_compute_LSE_pipeline.producer_commit(load_compute_LSE_producer_state) + load_compute_LSE_producer_state.advance() + + load_mma_dO_pipeline.producer_acquire(load_mma_dO_producer_state) + tma_barrier = load_mma_dO_pipeline.producer_get_barrier( + load_mma_dO_producer_state + ) + + # Load dO + cute.copy( + tma_atom_dO, + tdOgdO_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tdOsdO[None, load_mma_dO_producer_state.index], + tma_bar_ptr=tma_barrier, + ) + + load_mma_dO_producer_state.advance() + + load_compute_sum_OdO_pipeline.producer_acquire( + load_compute_sum_OdO_producer_state + ) + + # load sum_OdO + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = ( + self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + ) + if cute.elem_less(sum_OdO_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[ + None, sum_OdO_idx + i, (blk_coord_h, blk_coord_b) + ], + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_sum_OdO_producer_state.index, + ], + ) + else: + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + load_compute_sum_OdO_producer_state.index, + ].fill(0.0) + + load_compute_sum_OdO_pipeline.producer_commit( + load_compute_sum_OdO_producer_state + ) + load_compute_sum_OdO_producer_state.advance() + + iter_count -= 1 + iter_index += 1 + + @cute.jit + def mma( + self, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSK_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + tSTtST: cute.Tensor, + tSTrQ: cute.Tensor, + tSTrK: cute.Tensor, + tdPTtdPT: cute.Tensor, + tdPTrV: cute.Tensor, + tdPTrdO: cute.Tensor, + tdVtdV: cute.Tensor, + tdVrP: cute.Tensor, + tdVrdOT: cute.Tensor, + tdQtdQ: cute.Tensor, + tdQrdS: cute.Tensor, + tdQrKT: cute.Tensor, + tdKrdST: cute.Tensor, + tdKtdK: cute.Tensor, + tdKrQT: cute.Tensor, + tmem_holding_buf: Int32, + iter_count: Int32, + # (load_mma_Q_pipeline, mma_compute_S_pipeline, load_mma_dO_pipeline, mma_compute_dP_pipeline, mma_reduce_dQ_pipeline, compute_mma_P_pipeline, compute_mma_dS_pipeline, mma_compute_dKdV_pipeline) + pipeline_args: tuple, + ): + ( + load_mma_Q_pipeline, + mma_compute_S_pipeline, + load_mma_dO_pipeline, + mma_compute_dP_pipeline, + mma_reduce_dQ_pipeline, + compute_mma_P_pipeline, + compute_mma_dS_pipeline, + mma_compute_dKdV_pipeline, + ) = pipeline_args + # Alloc tmem buffer + tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) + cute.arch.alloc_tmem(tmem_alloc_cols, tmem_holding_buf) + self.tmem_alloc_barrier.arrive_and_wait() + load_mma_Q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_mma_Q_stage + ) + load_mma_Q_release_state = load_mma_Q_consumer_state.clone() + + mma_compute_S_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_compute_S_stage + ) + compute_mma_dS_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.compute_mma_dS_stage + ) + mma_compute_dP_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_compute_dP_stage + ) + mma_reduce_dQ_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_reduce_dQ_stage + ) + load_mma_dO_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_mma_dO_stage + ) + compute_mma_P_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.compute_mma_P_stage + ) + mma_compute_dKdV_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_compute_dKdV_stage + ) + + load_mma_Q_pipeline.consumer_wait(load_mma_Q_consumer_state) + mma_compute_S_pipeline.producer_acquire(mma_compute_S_producer_state) + + # S = K * Q + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tSTrQ, mode=[2]), unroll_full=True): + cute.gemm( + KQ_tiled_mma, + tSTtST, + tSTrK[None, None, k_block, 0], + tSTrQ[None, None, k_block, load_mma_Q_consumer_state.index], + tSTtST, + ) + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + load_mma_Q_consumer_state.advance() + mma_compute_S_pipeline.producer_commit(mma_compute_S_producer_state) + mma_compute_S_producer_state.advance() + + load_mma_dO_pipeline.consumer_wait(load_mma_dO_consumer_state) + + mma_compute_dP_pipeline.producer_acquire(mma_compute_dP_producer_state) + mma_reduce_dQ_pipeline.producer_acquire(mma_reduce_dQ_producer_state) + + # dP = V * dO + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdPTrV, mode=[2]), unroll_full=True): + cute.gemm( + VdO_tiled_mma, + tdPTtdPT, + tdPTrV[None, None, k_block, 0], + tdPTrdO[None, None, k_block, load_mma_dO_consumer_state.index], + tdPTtdPT, + ) + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + mma_compute_dP_pipeline.producer_commit(mma_compute_dP_producer_state) + mma_compute_dP_producer_state.advance() + + compute_mma_P_pipeline.consumer_wait(compute_mma_P_consumer_state) + + # dV = P * dO + for k_block in cutlass.range(0, cute.size(tdVrP, mode=[2]), unroll_full=True): + cute.gemm( + PdO_tiled_mma, + tdVtdV, + tdVrP[None, None, k_block], + tdVrdOT[None, None, k_block, load_mma_dO_consumer_state.index], + tdVtdV, + ) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + compute_mma_P_pipeline.consumer_release(compute_mma_P_consumer_state) + compute_mma_P_consumer_state.advance() + + load_mma_dO_pipeline.consumer_release(load_mma_dO_consumer_state) + load_mma_dO_consumer_state.advance() + + iter_count -= 1 + + while iter_count > 0: + load_mma_Q_pipeline.consumer_wait(load_mma_Q_consumer_state) + mma_compute_S_pipeline.producer_acquire(mma_compute_S_producer_state) + + # S = K * Q + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range( + 0, cute.size(tSTrQ, mode=[2]), unroll_full=True + ): + cute.gemm( + KQ_tiled_mma, + tSTtST, + tSTrK[None, None, k_block, 0], + tSTrQ[None, None, k_block, load_mma_Q_consumer_state.index], + tSTtST, + ) + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + load_mma_Q_consumer_state.advance() + mma_compute_S_pipeline.producer_commit(mma_compute_S_producer_state) + mma_compute_S_producer_state.advance() + + compute_mma_dS_pipeline.consumer_wait(compute_mma_dS_consumer_state) + + # We need to acquire dP here, because tmem dQ == tmem dP + mma_compute_dP_pipeline.producer_acquire(mma_compute_dP_producer_state) + + # dQ = dS * K + dSK_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range( + 0, cute.size(tdQrdS, mode=[2]), unroll_full=True + ): + cute.gemm( + dSK_tiled_mma, + tdQtdQ, + tdQrdS[ + None, + None, + k_block, + compute_mma_dS_consumer_state.index, + ], + tdQrKT[None, None, k_block, 0], + tdQtdQ, + ) + dSK_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + mma_reduce_dQ_pipeline.producer_commit(mma_reduce_dQ_producer_state) + mma_reduce_dQ_producer_state.advance() + + # dK = dS * Q + for k_block in cutlass.range( + 0, cute.size(tdKrdST, mode=[2]), unroll_full=True + ): + cute.gemm( + dSQ_tiled_mma, + tdKtdK, + tdKrdST[ + None, + None, + k_block, + compute_mma_dS_consumer_state.index, + ], + tdKrQT[None, None, k_block, load_mma_Q_release_state.index], + tdKtdK, + ) + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + load_mma_Q_pipeline.consumer_release(load_mma_Q_release_state) + load_mma_Q_release_state.advance() + + compute_mma_dS_pipeline.consumer_release(compute_mma_dS_consumer_state) + compute_mma_dS_consumer_state.advance() + + # We grab dQ here, because in tmem dQ == dP + mma_reduce_dQ_pipeline.producer_acquire(mma_reduce_dQ_producer_state) + load_mma_dO_pipeline.consumer_wait(load_mma_dO_consumer_state) + + # dP = V * dO + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range( + 0, cute.size(tdPTrV, mode=[2]), unroll_full=True + ): + cute.gemm( + VdO_tiled_mma, + tdPTtdPT, + tdPTrV[None, None, k_block, 0], + tdPTrdO[None, None, k_block, load_mma_dO_consumer_state.index], + tdPTtdPT, + ) + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + mma_compute_dP_pipeline.producer_commit(mma_compute_dP_producer_state) + mma_compute_dP_producer_state.advance() + + compute_mma_P_pipeline.consumer_wait(compute_mma_P_consumer_state) + + # dV = P * dO + for k_block in cutlass.range( + 0, cute.size(tdVrP, mode=[2]), unroll_full=True + ): + cute.gemm( + PdO_tiled_mma, + tdVtdV, + tdVrP[None, None, k_block], + tdVrdOT[None, None, k_block, load_mma_dO_consumer_state.index], + tdVtdV, + ) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + compute_mma_P_pipeline.consumer_release(compute_mma_P_consumer_state) + compute_mma_P_consumer_state.advance() + + load_mma_dO_pipeline.consumer_release(load_mma_dO_consumer_state) + load_mma_dO_consumer_state.advance() + + iter_count -= 1 + + # Signal to the epilogue that dV is ready + mma_compute_dKdV_pipeline.producer_acquire(mma_compute_dKdV_producer_state) + mma_compute_dKdV_pipeline.producer_commit(mma_compute_dKdV_producer_state) + mma_compute_dKdV_producer_state.advance() + + mma_compute_dKdV_pipeline.producer_acquire(mma_compute_dKdV_producer_state) + + compute_mma_dS_pipeline.consumer_wait(compute_mma_dS_consumer_state) + + # dK = dS * Q + for k_block in cutlass.range(0, cute.size(tdKrdST, mode=[2]), unroll_full=True): + cute.gemm( + dSQ_tiled_mma, + tdKtdK, + tdKrdST[None, None, k_block, compute_mma_dS_consumer_state.index], + tdKrQT[None, None, k_block, load_mma_Q_release_state.index], + tdKtdK, + ) + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Signal to epilogue that dK is ready + mma_compute_dKdV_pipeline.producer_commit(mma_compute_dKdV_producer_state) + mma_compute_dKdV_producer_state.advance() + + # We've already acquired mma_reduce_dq in the loop + + # dQ = dS * K + dSK_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdQrdS, mode=[2]), unroll_full=True): + cute.gemm( + dSK_tiled_mma, + tdQtdQ, + tdQrdS[None, None, k_block, compute_mma_dS_consumer_state.index], + tdQrKT[None, None, k_block, 0], + tdQtdQ, + ) + dSK_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + mma_reduce_dQ_pipeline.producer_commit(mma_reduce_dQ_producer_state) + mma_reduce_dQ_producer_state.advance() + + load_mma_Q_pipeline.consumer_release(load_mma_Q_release_state) + load_mma_Q_release_state.advance() + + compute_mma_dS_pipeline.consumer_release(compute_mma_dS_consumer_state) + compute_mma_dS_consumer_state.advance() + + @cute.jit + def compute( + self, + tSTtST: cute.Tensor, + tdPTtdPT: cute.Tensor, + tdVrP: cute.Tensor, + sLSE: cute.Tensor, + sdS: cute.Tensor, + sSum_OdO: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + tdKtdK: cute.Tensor, + tdVtdV: cute.Tensor, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + iter_count: Int32, + iter_start: Int32, + iter_end: Int32, + scale_softmax: Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + # (mma_compute_S_pipeline, compute_mma_P_pipeline, load_compute_LSE_pipeline, load_compute_sum_OdO_pipeline, mma_compute_dP_pipeline, compute_mma_dS_pipeline, mma_compute_dKdV_pipeline) + pipeline_args: tuple, + ): + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + + Q, K, D, HB = problem_shape + blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord + iter_index = iter_start + ( + mma_compute_S_pipeline, + compute_mma_P_pipeline, + load_compute_LSE_pipeline, + load_compute_sum_OdO_pipeline, + mma_compute_dP_pipeline, + compute_mma_dS_pipeline, + mma_compute_dKdV_pipeline, + ) = pipeline_args + + mma_compute_S_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_compute_S_stage + ) + compute_mma_P_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.compute_mma_P_stage + ) + load_compute_LSE_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_compute_LSE_stage + ) + load_compute_sum_OdO_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_compute_sum_OdO_stage + ) + mma_compute_dP_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_compute_dP_stage + ) + compute_mma_dS_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.compute_mma_dS_stage + ) + mma_compute_dKdV_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_compute_dKdV_stage + ) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), + self.acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), + self.element_dtype, + ) + + tSTtST = tSTtST[(None, None), 0, 0] + tdPTtdPT = tdPTtdPT[(None, None), 0, 0] + + cST = cute.make_identity_tensor(cute.select(self.KQ_mma_tiler, mode=[0, 1])) + cdPT = cute.make_identity_tensor(cute.select(self.VdO_mma_tiler, mode=[0, 1])) + + num_warp_groups = self.num_compute_warps // 4 + dp_idx = tidx % 128 + wg_idx = (tidx % (self.num_compute_warps * self.threads_per_warp)) // 128 + tiled_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tSTtST) + thr_t2r = tiled_t2r.get_slice(dp_idx) + + tTR_cST = thr_t2r.partition_D(cST) + tTR_cST = self.split_wg(tTR_cST, num_warp_groups, wg_idx) + tTR_rST = cute.make_rmem_tensor(tTR_cST.shape, self.acc_dtype) + + tTR_tST = thr_t2r.partition_S(tSTtST) + tTR_tST = self.split_wg(tTR_tST, num_warp_groups, wg_idx) + + tTR_cdPT_p = thr_t2r.partition_D(cdPT) + tTR_cdPT = self.split_wg(tTR_cdPT_p, num_warp_groups, wg_idx) + tTR_rdPT = cute.make_rmem_tensor(tTR_cdPT.shape, self.acc_dtype) + + tTR_tdPT = thr_t2r.partition_S(tdPTtdPT) + tTR_tdPT = self.split_wg(tTR_tdPT, num_warp_groups, wg_idx) + + tdVcST = PdO_tiled_mma.get_slice(0).partition_A(cST) + + tiled_r2t = tcgen05.make_tmem_copy(tmem_store_atom, tdVrP) + thr_r2t = tiled_r2t.get_slice(dp_idx) + + tRT_tP = thr_r2t.partition_D(tdVrP) + tRT_tP = self.split_wg(tRT_tP, num_warp_groups, wg_idx) + + tRT_cST = thr_r2t.partition_S(tdVcST) + tRT_cST = self.split_wg(tRT_cST, num_warp_groups, wg_idx) + + masked_leading_count = fmha_utils.FusedMask.get_masked_leading_count( + self.mask_type, + blk_coord, + self.cta_tiler, + Q, + K, + window_size_left, + window_size_right, + ) + unmasked_count = fmha_utils.FusedMask.get_unmasked_trip_count( + self.mask_type, + blk_coord, + self.cta_tiler, + Q, + K, + window_size_left, + window_size_right, + ) + masked_trailing_count = fmha_utils.FusedMask.get_masked_trailing_count( + self.mask_type, + blk_coord, + self.cta_tiler, + Q, + K, + window_size_left, + window_size_right, + ) + + while iter_count > 0: + # Wait for S and P + mma_compute_S_pipeline.consumer_wait(mma_compute_S_consumer_state) + compute_mma_P_pipeline.producer_acquire(compute_mma_P_producer_state) + # Wait for LSE + load_compute_LSE_pipeline.consumer_wait(load_compute_LSE_consumer_state) + + iter_num = iter_index - iter_start + 1 + # Compute P = softmax(S, LSE) + cute.copy(tiled_t2r, tTR_tST, tTR_rST) + + is_residual_k = Boolean(False) + is_residual_k = blk_coord_k * self.tile_shape_K + self.tile_shape_K >= K + + is_masked_tile = ( + is_residual_k + or iter_num <= masked_leading_count + or ( + iter_num > masked_leading_count + unmasked_count + and iter_num + <= masked_leading_count + unmasked_count + masked_trailing_count + ) + ) + + if is_masked_tile: + fmha_utils.FusedMask.apply_mask( + self.mask_type, + tTR_rST, + tTR_cST, + Q, + K, + window_size_left, + window_size_right, + lambda index_k, index_q: ( + index_q + iter_index * self.tile_shape_Q, + index_k + blk_coord_k * self.tile_shape_K, + ), + ) + + log2_e = Float32(math.log2(math.e)) + softmax_scale_log2_e = scale_softmax * log2_e + + for i in cutlass.range(0, cute.size(tTR_rST), 2, unroll_full=True): + lse = ( + sLSE[ + cute.get(tTR_cST[i], mode=[1]), + load_compute_LSE_consumer_state.index, + ], + sLSE[ + cute.get(tTR_cST[i + 1], mode=[1]), + load_compute_LSE_consumer_state.index, + ], + ) + tTR_rST[i], tTR_rST[i + 1] = cute.arch.fma_packed_f32x2( + (tTR_rST[i], tTR_rST[i + 1]), + (softmax_scale_log2_e, softmax_scale_log2_e), + lse, + ) + tTR_rST[i] = cute.math.exp2(tTR_rST[i], fastmath=True) + tTR_rST[i + 1] = cute.math.exp2(tTR_rST[i + 1], fastmath=True) + + # convert fp32 P to fp16 P which will be used in the PdO + tRT_rST = self.quantize(tTR_rST, 4) + + tRT_rST_reshaped = cute.make_tensor( + tRT_rST.iterator, cute.make_layout(tRT_cST.shape) + ) + + cute.arch.fence_view_async_tmem_load() + self.compute_sync_barrier.arrive_and_wait() + cute.arch.fence_view_async_tmem_load() + + cute.copy(tiled_r2t, tRT_rST_reshaped, tRT_tP) + + cute.arch.fence_view_async_tmem_store() + + # Notify for P + compute_mma_P_pipeline.producer_commit(compute_mma_P_producer_state) + compute_mma_P_producer_state.advance() + + # Release S + mma_compute_S_pipeline.consumer_release(mma_compute_S_consumer_state) + mma_compute_S_consumer_state.advance() + + # Release LSE + load_compute_LSE_pipeline.consumer_release(load_compute_LSE_consumer_state) + load_compute_LSE_consumer_state.advance() + + # Wait for OdO + load_compute_sum_OdO_pipeline.consumer_wait( + load_compute_sum_OdO_consumer_state + ) + # Wait for dP + mma_compute_dP_pipeline.consumer_wait(mma_compute_dP_consumer_state) + + # Wait for dS + compute_mma_dS_pipeline.producer_acquire(compute_mma_dS_producer_state) + + # Compute dS = dsoftmax(P, dP, sum_OdO) + cute.copy(tiled_t2r, tTR_tdPT, tTR_rdPT) + + for i in cutlass.range(0, cute.size(tTR_rdPT), 2, unroll_full=True): + tTR_rdPT[i], tTR_rdPT[i + 1] = cute.arch.add_packed_f32x2( + (tTR_rdPT[i], tTR_rdPT[i + 1]), + ( + sSum_OdO[ + cute.get(tTR_cdPT[i], mode=[1]), + load_compute_sum_OdO_consumer_state.index, + ], + sSum_OdO[ + cute.get(tTR_cdPT[i + 1], mode=[1]), + load_compute_sum_OdO_consumer_state.index, + ], + ), + ) + tTR_rdPT[i], tTR_rdPT[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rdPT[i], tTR_rdPT[i + 1]), (tTR_rST[i], tTR_rST[i + 1]) + ) + # convert fp32 dS to fp16 dS which will be used in the computation of dK and DQ + tTR_rdST = self.quantize(tTR_rdPT, 4) + + # Release dP + cute.arch.fence_view_async_tmem_load() + mma_compute_dP_pipeline.consumer_release(mma_compute_dP_consumer_state) + mma_compute_dP_consumer_state.advance() + + sdS_slice = sdS[None, None, None, compute_mma_dS_producer_state.index] + + thread_layout = cute.make_ordered_layout((128, 128), (1, 0)) + sdS_slice_tmp = cute.composition(sdS_slice, thread_layout) + sdS_slice_p = cute.composition( + sdS_slice_tmp[dp_idx, None], cute.make_layout(tTR_cdPT_p.shape) + ) + sdS_slice = self.split_wg(sdS_slice_p, num_warp_groups, wg_idx) + + cute.autovec_copy(tTR_rdST, sdS_slice) + + # Notify for dS + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + compute_mma_dS_pipeline.producer_commit(compute_mma_dS_producer_state) + compute_mma_dS_producer_state.advance() + + # Release OdO + load_compute_sum_OdO_pipeline.consumer_release( + load_compute_sum_OdO_consumer_state + ) + load_compute_sum_OdO_consumer_state.advance() + + iter_count -= 1 + iter_index += 1 + if iter_index == iter_end: + iter_index = iter_start + + # Epilogue + self.epilogue( + blk_coord, + blk_offset, + problem_shape, + dK, + dV, + tdKtdK, + tdVtdV, + scale_softmax, + (mma_compute_dKdV_pipeline, mma_compute_dKdV_consumer_state), + ) + + @cute.jit + def reduce( + self, + dSK_tiled_mma: cute.TiledMma, + tdQtdQ: cute.Tensor, + tma_atom_dQ_acc: cute.CopyAtom, + mdQ_acc: cute.Tensor, + sdQ: cute.Tensor, + blk_coord: cute.Coord, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + iter_count: Int32, + iter_start: Int32, + iter_end: Int32, + # (mma_reduce_dQ_pipeline, reduce_tma_store_pipeline) + pipeline_args: tuple, + ): + tidx, tidy, tidz = cute.arch.thread_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + Q, K, D, HB = problem_shape + blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord + + blk_coord_h, blk_coord_b = blk_coord_batch + blk_coord_h_r, blk_coord_h_k = blk_coord_h + + iter_index = iter_start + + mma_reduce_dQ_pipeline, reduce_tma_store_pipeline = pipeline_args + + mma_reduce_dQ_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_reduce_dQ_stage + ) + reduce_tma_store_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.reduce_tma_store_stage + ) + + load_op = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + + gdQ = cute.local_tile(mdQ_acc, (self.KQ_mma_tiler[1], 32), (None, None, None)) + + cdQ = cute.make_identity_tensor((self.dSK_mma_tiler[0], self.dSK_mma_tiler[1])) + + thread_idx = tidx % (self.num_compute_warps * self.threads_per_warp) + + tdQtdQ = tdQtdQ[(None, None), 0, 0] + + tiled_t2r = tcgen05.make_tmem_copy(load_op, tdQtdQ) + thr_t2r = tiled_t2r.get_slice(thread_idx) + + tTR_cdQ = thr_t2r.partition_D(cdQ) + tTR_gdQ = thr_t2r.partition_D(gdQ) + tTR_sdQ = thr_t2r.partition_D(sdQ) + tTR_tdQ = thr_t2r.partition_S(tdQtdQ) + + tdQsdQ, tdQgdQ = cute.nvgpu.cpasync.tma_partition( + tma_atom_dQ_acc, + 0, + cute.make_layout(1), + cute.group_modes(sdQ, 0, 2), + cute.group_modes(gdQ, 0, 2), + ) + + while iter_count > 0: + mma_reduce_dQ_pipeline.consumer_wait(mma_reduce_dQ_consumer_state) + + tTR_rdQ = cute.make_rmem_tensor(tTR_cdQ.shape, self.acc_dtype) + + # Load dQ from tmem to rmem + cute.copy(tiled_t2r, tTR_tdQ, tTR_rdQ) + + cute.arch.fence_view_async_tmem_load() + + mma_reduce_dQ_pipeline.consumer_release(mma_reduce_dQ_consumer_state) + mma_reduce_dQ_consumer_state.advance() + + # We don't have enough smem to dump it all to smem, so we do it in stages + for i in cutlass.range(0, cute.size(tTR_cdQ, mode=[2]), unroll_full=True): + if warp_idx == 0: + reduce_tma_store_pipeline.producer_acquire() + # Wait in all threads for the acquire to complete + self.reduce_sync_barrier.arrive_and_wait() + + cute.autovec_copy( + tTR_rdQ[None, None, i], + tTR_sdQ[None, None, 0, reduce_tma_store_producer_state.index], + ) + + # Wait for the stores to all be visible to the TMA + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.reduce_sync_barrier.arrive_and_wait() + + if warp_idx == 0: + cute.copy( + tma_atom_dQ_acc, + tdQsdQ[None, reduce_tma_store_producer_state.index], + tdQgdQ[None, iter_index, i, blk_coord_batch], + ) + + reduce_tma_store_pipeline.producer_commit() + + reduce_tma_store_producer_state.advance() + + iter_count -= 1 + iter_index += 1 + if iter_index == iter_end: + iter_index = iter_start + blk_coord_h_r += 1 + blk_coord_batch = ((blk_coord_h_r, blk_coord_h_k), blk_coord_b) + + reduce_tma_store_pipeline.producer_tail() + + @cute.jit + def split_wg( + self, + t: cute.Tensor, + num_warp_groups: Int32, + wg_idx: Int32, + ) -> cute.Tensor: + ret = None + if cutlass.const_expr(cute.rank(t.layout) == 3): + p = cute.composition( + t, + cute.make_layout( + ( + t.shape[0], + t.shape[1], + (num_warp_groups, cute.size(t, mode=[2]) // num_warp_groups), + ) + ), + ) + ret = p[None, None, (wg_idx, None)] + else: + p = cute.composition( + t, + cute.make_layout( + ( + t.shape[0], + t.shape[1], + t.shape[2], + (num_warp_groups, cute.size(t, mode=[3]) // num_warp_groups), + ) + ), + ) + ret = p[None, None, None, (wg_idx, None)] + return ret + + @cute.jit + def quantize( + self, + input: cute.Tensor, + frg_cnt: Int32, + ) -> cute.Tensor: + tidx, tidy, tidz = cute.arch.thread_idx() + output = cute.make_rmem_tensor(input.shape, self.element_dtype) + frg_tile = cute.size(input) // frg_cnt + t_frg = cute.logical_divide(input, cute.make_layout(frg_cnt)) + output_frg = cute.make_tensor(output.iterator, t_frg.layout) + for i in cutlass.range(frg_tile, unroll_full=True): + frg_vec = t_frg[None, i].load() + output_frg[None, i].store(frg_vec.to(self.element_dtype)) + return output + + @cute.jit + def store( + self, + gmem: cute.Tensor, + regs: cute.Tensor, + coord: cute.Tensor, + tensor_shape: cute.Shape, + ): + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.element_dtype, + num_bits_per_copy=128, + ) + copy_op = cute.make_cotiled_copy( + copy_atom, + cute.make_layout((1, 128 // self.element_dtype.width)), + regs.layout, + ) + thr_copy = copy_op.get_slice(0) + + tCg = thr_copy.partition_D(gmem) + tCr = thr_copy.partition_S(self.quantize(regs, 4)) + tPc = thr_copy.partition_D(coord) + + preds_shape = (tPc.shape[0][1], tPc.shape[1], tPc.shape[2], tPc.shape[3]) + preds = cute.make_rmem_tensor(preds_shape, Boolean) + for v in cutlass.range_constexpr(preds.shape[0]): + for m in cutlass.range_constexpr(preds.shape[1]): + for n in cutlass.range_constexpr(preds.shape[2]): + for k in cutlass.range_constexpr(preds.shape[3]): + lhs = tPc[(0, v), m, n, k] + val = cute.elem_less(lhs, tensor_shape) + preds[v, m, n, k] = val + + cute.copy(copy_atom, tCr, tCg, pred=preds) + + @cute.jit + def epilogue( + self, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + dK: cute.Tensor, + dV: cute.Tensor, + tdKtdK: cute.Tensor, + tdVtdV: cute.Tensor, + scale_softmax: Float32, + # (mma_compute_dKdV_pipeline, mma_compute_dKdV_consumer_state) + pipeline_args: tuple, + ): + tidx, tidy, tidz = cute.arch.thread_idx() + Q, K, D, HB = problem_shape + blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord + mma_compute_dKdV_pipeline, mma_compute_dKdV_consumer_state = pipeline_args + + load_op = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), + self.acc_dtype, + ) + + tdKtdK = tdKtdK[(None, None), 0, 0] + + mdK = cute.make_tensor( + dK.iterator + cute.assume(blk_offset[1] * dK.stride[0], divby=64), + cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), + ) + gdK = cute.local_tile( + mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) + ) + gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] + + cdK = cute.domain_offset( + (blk_coord_k * self.tile_shape_K, 0), + cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + ) + + num_warp_groups = self.num_compute_warps // 4 + dp_idx = tidx % 128 + wg_idx = (tidx % (self.num_compute_warps * self.threads_per_warp)) // 128 + + tiled_t2r_dK = tcgen05.make_tmem_copy(load_op, tdKtdK) + thread_t2r_dK = tiled_t2r_dK.get_slice(dp_idx) + + tTR_cdK = thread_t2r_dK.partition_D(cdK) + tTR_cdK = self.split_wg(tTR_cdK, num_warp_groups, wg_idx) + tTR_gdK = thread_t2r_dK.partition_D(gdK) + tTR_gdK = self.split_wg(tTR_gdK, num_warp_groups, wg_idx) + tTR_rdK = cute.make_rmem_tensor(tTR_cdK.shape, self.acc_dtype) + tTR_tdK = thread_t2r_dK.partition_S(tdKtdK) + tTR_tdK = self.split_wg(tTR_tdK, num_warp_groups, wg_idx) + + mdV_in = cute.make_tensor( + dV.iterator, cute.make_layout((K, self.cta_tiler[2], HB), stride=dV.stride) + ) + mdV = cute.make_tensor( + mdV_in.iterator + cute.assume(blk_offset[1] * mdV_in.stride[0], divby=64), + mdV_in.layout, + ) + gdV = cute.local_tile( + mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) + ) + gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] + + cdV = cute.domain_offset( + (blk_coord_k * self.cta_tiler[1], 0), + cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + ) + + tdVtdV = tdVtdV[(None, None), 0, 0] + + tiled_t2r_dV = tcgen05.make_tmem_copy(load_op, tdVtdV) + thread_t2r_dV = tiled_t2r_dV.get_slice(dp_idx) + + tTR_cdV = thread_t2r_dV.partition_D(cdV) + tTR_cdV = self.split_wg(tTR_cdV, num_warp_groups, wg_idx) + tTR_gdV = thread_t2r_dV.partition_D(gdV) + tTR_gdV = self.split_wg(tTR_gdV, num_warp_groups, wg_idx) + tTR_rdV = cute.make_rmem_tensor(tTR_cdV.shape, self.acc_dtype) + tTR_tdV = thread_t2r_dV.partition_S(tdVtdV) + tTR_tdV = self.split_wg(tTR_tdV, num_warp_groups, wg_idx) + + mma_compute_dKdV_pipeline.consumer_wait(mma_compute_dKdV_consumer_state) + + # Load tdVtdV + cute.copy(tiled_t2r_dV, tTR_tdV, tTR_rdV) + + # Store tdVgdV + self.store(tTR_gdV, tTR_rdV, tTR_cdV, (K, D)) + + cute.arch.fence_view_async_tmem_load() + + mma_compute_dKdV_pipeline.consumer_release(mma_compute_dKdV_consumer_state) + mma_compute_dKdV_consumer_state.advance() + + mma_compute_dKdV_pipeline.consumer_wait(mma_compute_dKdV_consumer_state) + + cute.copy(tiled_t2r_dK, tTR_tdK, tTR_rdK) + + for i in cutlass.range(cute.size(tTR_rdK), unroll_full=True): + tTR_rdK[i] = scale_softmax * tTR_rdK[i] + + self.store(tTR_gdK, tTR_rdK, tTR_cdK, (K, D)) + + cute.arch.fence_view_async_tmem_load() + mma_compute_dKdV_pipeline.consumer_release(mma_compute_dKdV_consumer_state) + mma_compute_dKdV_consumer_state.advance() + + def get_workspace_tensor( + self, + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + workspace: cute.Tensor, + acc_dtype: Type[cutlass.Numeric], + ) -> Tuple[cute.Tensor, cute.Tensor, cute.Tensor]: + Q, D, HB = ( + problem_shape[0], + problem_shape[2], + problem_shape[3], + ) + H, B = cute.size(problem_shape[3][0]), cute.size(problem_shape[3][1]) + H_r, H_k = problem_shape[3][0] + D = cute.round_up(D, 8) + Q = cute.round_up(Q, 8) + + acc_bytes = acc_dtype.width // 8 + sum_OdO_bytes = cute.assume(B * H * Q * acc_bytes, divby=acc_bytes) + scaled_lse_bytes = cute.assume(B * H * Q * acc_bytes, divby=acc_bytes) + dQ_acc_bytes = cute.assume(B * H * Q * D * acc_bytes, divby=acc_bytes) + + sum_OdO_iter = workspace.iterator + scaled_lse_iter = sum_OdO_iter + sum_OdO_bytes + dQ_acc_iter = scaled_lse_iter + scaled_lse_bytes + + sum_OdO_iter = cute.recast_ptr(sum_OdO_iter, dtype=self.acc_dtype) + scaled_lse_iter = cute.recast_ptr(scaled_lse_iter, dtype=self.acc_dtype) + dQ_acc_iter = cute.recast_ptr(dQ_acc_iter, dtype=self.acc_dtype) + + sum_OdO = cute.make_tensor( + sum_OdO_iter, + cute.make_layout((Q, ((H_r, H_k), B)), stride=(1, ((Q, Q * H_r), Q * H))), + ) + scaled_lse = cute.make_tensor( + scaled_lse_iter, + cute.make_layout((Q, ((H_r, H_k), B)), stride=(1, ((Q, Q * H_r), Q * H))), + ) + dQ_acc = cute.make_tensor( + dQ_acc_iter, + cute.make_layout( + (Q, D, ((H_r, H_k), B)), + stride=(D, 1, ((D * Q, D * Q * H_r), D * Q * H)), + ), + ) + + return sum_OdO, scaled_lse, dQ_acc + + @staticmethod + def _compute_sum_OdO_grid( + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + block_q: int, + ) -> Tuple[int, int, int]: + grid = ( + cute.ceil_div(cute.size(problem_shape[0]), block_q), + cute.size(problem_shape[3][0]), # H + cute.size(problem_shape[3][1]), # B + ) + return grid + + @staticmethod + def _compute_bwd_grid( + problem_shape: Tuple[Int32, Int32, Int32, Tuple[Tuple[Int32, Int32], Int32]], + block_k: int, + ) -> Tuple[int, int, int]: + K = problem_shape[1] + H_R, H_K = problem_shape[3][0] + B = problem_shape[3][1] + return (cute.ceil_div(K, block_k), cute.size(H_K), cute.size(B)) + + @staticmethod + def _get_workspace_size( + q: int, k: int, d: int, h: int, b: int, acc_dtype: Type[cutlass.Numeric] + ): + d = (d + 7) // 8 * 8 # round up to 8 + q = (q + 7) // 8 * 8 # round up to 8 + workspace_bytes = 0 + # OdO vector + workspace_bytes += b * h * q * acc_dtype.width // 8 + # scaled LSE vector + workspace_bytes += b * h * q * acc_dtype.width // 8 + # FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += b * h * q * d * acc_dtype.width // 8 + return workspace_bytes + + def make_and_init_load_mma_Q_pipeline(self, load_mma_Q_mbar_ptr): + load_mma_Q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_Q_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_Q_mbar_ptr, + num_stages=self.load_mma_Q_stage, + producer_group=load_mma_Q_producer_group, + consumer_group=load_mma_Q_consumer_group, + tx_count=self.tma_copy_Q_bytes, + ) + + def make_and_init_load_mma_dO_pipeline(self, load_mma_dO_mbar_ptr): + load_mma_dO_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_dO_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_dO_mbar_ptr, + num_stages=self.load_mma_dO_stage, + producer_group=load_mma_dO_producer_group, + consumer_group=load_mma_dO_consumer_group, + tx_count=self.tma_copy_dO_bytes, + ) + + def make_and_init_load_compute_LSE_pipeline(self, load_compute_lse_mbar_ptr): + load_compute_lse_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp, + ) + load_compute_lse_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * self.num_compute_warps, + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_compute_lse_mbar_ptr, + num_stages=self.load_compute_LSE_stage, + producer_group=load_compute_lse_producer_group, + consumer_group=load_compute_lse_consumer_group, + ) + + def make_and_init_load_compute_sum_OdO_pipeline( + self, load_compute_sum_OdO_mbar_ptr + ): + load_compute_sum_OdO_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp, + ) + load_compute_sum_OdO_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * self.num_compute_warps, + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_compute_sum_OdO_mbar_ptr, + num_stages=self.load_compute_sum_OdO_stage, + producer_group=load_compute_sum_OdO_producer_group, + consumer_group=load_compute_sum_OdO_consumer_group, + ) + + def make_and_init_mma_compute_S_pipeline(self, mma_compute_S_mbar_ptr): + mma_compute_S_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_S_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_S_mbar_ptr, + num_stages=self.mma_compute_S_stage, + producer_group=mma_compute_S_producer_group, + consumer_group=mma_compute_S_consumer_group, + ) + + def make_and_init_mma_compute_dP_pipeline(self, mma_compute_dP_mbar_ptr): + mma_compute_dP_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_dP_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_dP_mbar_ptr, + num_stages=self.mma_compute_dP_stage, + producer_group=mma_compute_dP_producer_group, + consumer_group=mma_compute_dP_consumer_group, + ) + + def make_and_init_mma_reduce_dQ_pipeline(self, mma_reduce_dQ_mbar_ptr): + mma_reduce_dQ_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_reduce_dQ_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_reduce_warps * self.threads_per_warp, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_reduce_dQ_mbar_ptr, + num_stages=self.mma_reduce_dQ_stage, + producer_group=mma_reduce_dQ_producer_group, + consumer_group=mma_reduce_dQ_consumer_group, + ) + + def make_and_init_compute_mma_P_pipeline(self, compute_mma_P_mbar_ptr): + compute_mma_P_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp, + ) + compute_mma_P_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=compute_mma_P_mbar_ptr, + num_stages=self.compute_mma_P_stage, + producer_group=compute_mma_P_producer_group, + consumer_group=compute_mma_P_consumer_group, + ) + + def make_and_init_compute_mma_dS_pipeline(self, compute_mma_dS_mbar_ptr): + compute_mma_dS_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp, + ) + compute_mma_dS_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + + return pipeline.PipelineAsyncUmma.create( + barrier_storage=compute_mma_dS_mbar_ptr, + num_stages=self.compute_mma_dS_stage, + producer_group=compute_mma_dS_producer_group, + consumer_group=compute_mma_dS_consumer_group, + ) + + def make_and_init_mma_compute_dKdV_pipeline(self, mma_compute_dKdV_mbar_ptr): + mma_compute_dKdV_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_dKdV_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_dKdV_mbar_ptr, + num_stages=self.mma_compute_dKdV_stage, + producer_group=mma_compute_dKdV_producer_group, + consumer_group=mma_compute_dKdV_consumer_group, + ) + + def make_and_init_reduce_tma_store_pipeline(self): + reduce_tma_store_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_reduce_warps * self.threads_per_warp, + ) + return pipeline.PipelineTmaStore.create( + num_stages=self.reduce_tma_store_stage, + producer_group=reduce_tma_store_producer_group, + ) + + +def run( + s_q: int | Tuple[int, ...], + s_k: int | Tuple[int, ...], + h_q: int, + h_k: int, + d: int, + b: int, + is_causal: bool, + bottom_right_align: bool, + element_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + scale_softmax: float, + window_size: Tuple[int, int], + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + print("Running Blackwell SM100 FMHA bwd test with:") + print(f" s_q: {s_q}") + print(f" s_k: {s_k}") + print(f" h_q: {h_q}") + print(f" h_k: {h_k}") + print(f" d: {d}") + print(f" b: {b}") + print(f" is_causal: {is_causal}") + print(f" bottom_right_align: {bottom_right_align}") + print(f" element_dtype: {element_dtype}") + print(f" acc_dtype: {acc_dtype}") + print(f" mma_tiler_mn: {mma_tiler_mn}") + print(f" scale_softmax: {scale_softmax}") + print(f" window_size: {window_size}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + + torch.manual_seed(42) + random.seed(123) + + if d not in {64, 128}: + raise ValueError("head dimension must be 64, or 128") + + if h_q % h_k != 0: + raise ValueError("h_q must be divisible by h_k") + + if element_dtype not in {Float8E4M3FN, Float16, BFloat16}: + raise ValueError("in_dtype must be Float8E4M3FN or Float16 or BFloat16") + + if acc_dtype not in {Float32}: + raise ValueError("acc_dtype must be Float32") + + if iterations < 1: + raise ValueError("iterations must be at least 1") + + if isinstance(s_q, tuple) and len(s_q) != b: + raise ValueError("s_q must be a tuple of length b") + + window_size_left, window_size_right = window_size + if window_size_left == -1: + window_size_left = None + if window_size_right == -1: + window_size_right = None + + h_r = h_q // h_k + orig_b = b + + if scale_softmax == 0.0: + scale_softmax = 1.0 / math.sqrt(d) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + def create_and_permute_tensor( + shape, + permute_order, + dtype, + min_val=-2, + max_val=2, + is_dynamic_layout=True, + zero_out=False, + ): + # (b, s, h_k, h_r, d) -> (s, d, h_r, h_k, b) + ref_tensor = ( + torch.empty(*shape, dtype=torch.float32) + .random_(min_val, max_val) + .permute(permute_order) + ) + if zero_out: + ref_tensor.zero_() + + torch_dtype = cutlass_torch.dtype(dtype) + + dst_tensor = ref_tensor.to(dtype=torch_dtype).cuda() + cute_tensor = from_dlpack(dst_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=1 + ).mark_compact_shape_dynamic( + mode=1, stride_order=(4, 0, 3, 2, 1), divisibility=64 + ) + + return ref_tensor, cute_tensor, dst_tensor + + s_q_list = s_q if isinstance(s_q, tuple) else [s_q] * b + s_k_list = s_k if isinstance(s_k, tuple) else [s_k] * b + + # To avoid mask out the whole row which results in NaN in softmax + def check_seqlen_valid( + s_q, s_k, window_size_left, window_size_right, bottom_right_align + ): + for i in range(s_q): + offset = 0 if not bottom_right_align else s_k - s_q + + s_q_start = 0 if window_size_left is None else i + offset - window_size_left + s_q_end = ( + s_q if window_size_right is None else i + offset + window_size_right + ) + s_q_min = max(s_q_start, 0) + s_q_max = min(s_q_end, s_k) + + if s_q_max - s_q_min == 0 and (i != 0 and i != s_q - 1): + return False + return True + + need_check_seqlen_valid = ( + window_size_left is not None or window_size_right is not None + ) + for i in range(b): + if need_check_seqlen_valid and not check_seqlen_valid( + s_q_list[i], + s_k_list[i], + window_size_left, + window_size_right, + bottom_right_align, + ): + raise ValueError("sliding window doesn't support current setting") + + # create sequence lengths for variable length inputs + cumulative_s_q = [0] + cumulative_s_k = [0] + varlen = False + if isinstance(s_q, tuple): + varlen = True + for i in range(b): + cumulative_s_q.append(cumulative_s_q[-1] + s_q[i]) + cumulative_s_k.append(cumulative_s_k[-1] + s_k[i]) + s_q_max = max(s_q) + s_k_max = max(s_k) + s_q = sum(s_q) + s_k = sum(s_k) + b = 1 + else: + s_q_max = s_q + s_k_max = s_k + + mask_type = fmha_utils.MaskEnum.WINDOW_MASK_BWD + if bottom_right_align: + mask_type = fmha_utils.MaskEnum.WINDOW_MASK_BWD_INFERENCE + if is_causal: + window_size_right = 0 + elif window_size_left is None and window_size_right is None: + if varlen or s_q % mma_tiler_mn[0] != 0: + mask_type = fmha_utils.MaskEnum.RESIDUAL_MASK_BWD + + problem_shape = (s_q_max, s_k_max, d, ((h_r, h_k), orig_b)) + cumulative_s_q_torch_tensor = ( + torch.tensor(cumulative_s_q, dtype=torch.int32).cuda() if varlen else None + ) + cumulative_s_k_torch_tensor = ( + torch.tensor(cumulative_s_k, dtype=torch.int32).cuda() if varlen else None + ) + cumulative_s_q_cute_tensor = ( + from_dlpack(cumulative_s_q_torch_tensor).mark_layout_dynamic() + if varlen + else None + ) + cumulative_s_k_cute_tensor = ( + from_dlpack(cumulative_s_k_torch_tensor).mark_layout_dynamic() + if varlen + else None + ) + + q_ref, q_tensor, q_torch = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + dq_ref, dq_tensor, dq_torch = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + zero_out=True, + ) + k_ref, k_tensor, k_torch = create_and_permute_tensor( + (b, s_k, h_k, 1, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + dk_ref, dk_tensor, dk_torch = create_and_permute_tensor( + (b, s_k, h_k, 1, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + zero_out=True, + ) + v_ref, v_tensor, v_torch = create_and_permute_tensor( + (b, s_k, h_k, 1, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + dv_ref, dv_tensor, dv_torch = create_and_permute_tensor( + (b, s_k, h_k, 1, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + zero_out=True, + ) + do_ref, do_tensor, do_torch = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + o_ref, o_tensor, o_torch = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + + lse_ref = cutlass_torch.create_and_permute_torch_tensor( + (b, h_k, h_r, s_q), + cutlass.torch.dtype(acc_dtype), + permute_order=(3, 2, 1, 0), + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig(min_val=10, max_val=11), + ) + lse_torch = lse_ref.cuda() + lse_tensor = from_dlpack(lse_torch, assumed_align=16) + lse_tensor = lse_tensor.mark_layout_dynamic(leading_dim=0) + + mma_tiler = (*mma_tiler_mn, d) + + fmha_bwd = BlackwellFusedMultiHeadAttentionBackward( + element_dtype, acc_dtype, mma_tiler, varlen, mask_type + ) + + workspace_size = BlackwellFusedMultiHeadAttentionBackward._get_workspace_size( + s_q_max, s_k_max, d, h_q, orig_b, acc_dtype + ) + workspace_torch = torch.zeros(workspace_size, dtype=torch.uint8).cuda() + workspace = from_dlpack(workspace_torch, assumed_align=16).mark_layout_dynamic() + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + compiled_fmha_bwd = cute.compile( + fmha_bwd, + problem_shape, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + dq_tensor, + dk_tensor, + dv_tensor, + do_tensor, + lse_tensor, + cumulative_s_q_cute_tensor, + cumulative_s_k_cute_tensor, + scale_softmax, + window_size_left if window_size_left is None else Int32(window_size_left), + window_size_right if window_size_right is None else Int32(window_size_right), + workspace, + current_stream, + ) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + for _ in range(warmup_iterations): + compiled_fmha_bwd( + problem_shape, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + dq_tensor, + dk_tensor, + dv_tensor, + do_tensor, + lse_tensor, + cumulative_s_q_cute_tensor, + cumulative_s_k_cute_tensor, + scale_softmax, + window_size_left if window_size_left is None else Int32(window_size_left), + ( + window_size_right + if window_size_right is None + else Int32(window_size_right) + ), + workspace, + current_stream, + ) + + for _ in range(iterations): + compiled_fmha_bwd( + problem_shape, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + dq_tensor, + dk_tensor, + dv_tensor, + do_tensor, + lse_tensor, + cumulative_s_q_cute_tensor, + cumulative_s_k_cute_tensor, + scale_softmax, + window_size_left if window_size_left is None else Int32(window_size_left), + ( + window_size_right + if window_size_right is None + else Int32(window_size_right) + ), + workspace, + current_stream, + ) + + if not skip_ref_check: + workspace_torch.fill_(0) + compiled_fmha_bwd( + problem_shape, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + dq_tensor, + dk_tensor, + dv_tensor, + do_tensor, + lse_tensor, + cumulative_s_q_cute_tensor, + cumulative_s_k_cute_tensor, + scale_softmax, + window_size_left if window_size_left is None else Int32(window_size_left), + ( + window_size_right + if window_size_right is None + else Int32(window_size_right) + ), + workspace, + current_stream, + ) + torch.cuda.synchronize() + print("Verifying results...") + + q_ref = q_ref.cuda().to(cutlass.torch.dtype(element_dtype)) + k_ref = k_ref.cuda().to(cutlass.torch.dtype(element_dtype)) + v_ref = v_ref.cuda().to(cutlass.torch.dtype(element_dtype)) + o_ref = o_ref.cuda().to(cutlass.torch.dtype(element_dtype)) + do_ref = do_ref.cuda().to(cutlass.torch.dtype(element_dtype)) + dv = dv_torch.to(dtype=torch.float32) + dk = dk_torch.to(dtype=torch.float32) + dq = dq_torch.to(dtype=torch.float32) + + dv_ref, dk_ref, dq_ref = fmha_bwd_reference( + problem_shape, + q_ref, + k_ref, + v_ref, + do_ref, + o_ref, + lse_torch, + cumulative_s_q_torch_tensor, + cumulative_s_k_torch_tensor, + is_causal, + bottom_right_align, + window_size_left, + window_size_right, + ) + dv_pt, dk_pt, dq_pt = fmha_bwd_reference( + problem_shape, + q_ref, + k_ref, + v_ref, + do_ref, + o_ref, + lse_torch, + cumulative_s_q_torch_tensor, + cumulative_s_k_torch_tensor, + is_causal, + bottom_right_align, + window_size_left, + window_size_right, + upcast=False, + reorder_ops=True, + ) + + rtol = 2 + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + print(f"Pytorch dv max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"Pytorch dv mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + print(f"Pytorch dk max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"Pytorch dk mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"Pytorch dq max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"Pytorch dq mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + + print(f"dv max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dv mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dk max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dk mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dq max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dq mean diff: {(dq - dq_ref).abs().mean().item()}") + + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + + print("Results verified successfully!") + + def generate_tensors(): + _, q_tensor_new, _ = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + ) + _, dq_tensor_new, _ = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + ) + _, k_tensor_new, _ = create_and_permute_tensor( + (b, s_k, h_k, 1, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + ) + _, dk_tensor_new, _ = create_and_permute_tensor( + (b, s_k, h_k, 1, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + _, v_tensor_new, _ = create_and_permute_tensor( + (b, s_k, h_k, 1, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + _, dv_tensor_new, _ = create_and_permute_tensor( + (b, s_k, h_k, 1, d), (1, 4, 3, 2, 0), element_dtype, is_dynamic_layout=True + ) + _, do_tensor_new, _ = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + ) + _, o_tensor_new, _ = create_and_permute_tensor( + (b, s_q, h_k, h_r, d), + (1, 4, 3, 2, 0), + element_dtype, + is_dynamic_layout=True, + ) + + lse_ref_new = cutlass_torch.create_and_permute_torch_tensor( + (b, h_k, h_r, s_q), + cutlass.torch.dtype(acc_dtype), + permute_order=(3, 2, 1, 0), + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig(min_val=10, max_val=11), + ) + lse_torch_new = lse_ref_new.cuda() + lse_tensor_new = from_dlpack(lse_torch_new, assumed_align=16) + lse_tensor_new = lse_tensor_new.mark_layout_dynamic(leading_dim=0) + + return testing.JitArguments( + problem_shape, + q_tensor_new, + k_tensor_new, + v_tensor_new, + o_tensor_new, + dq_tensor_new, + dk_tensor_new, + dv_tensor_new, + do_tensor_new, + lse_tensor_new, + cumulative_s_q_cute_tensor, + cumulative_s_k_cute_tensor, + scale_softmax, + window_size_left if window_size_left is None else Int32(window_size_left), + ( + window_size_right + if window_size_right is None + else Int32(window_size_right) + ), + workspace, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_torch.numel() * q_torch.element_size() + + dq_torch.numel() * dq_torch.element_size() + + k_torch.numel() * k_torch.element_size() + + dk_torch.numel() * dk_torch.element_size() + + v_torch.numel() * v_torch.element_size() + + dv_torch.numel() * dv_torch.element_size() + + do_torch.numel() * do_torch.element_size() + + o_torch.numel() * o_torch.element_size() + + lse_torch.numel() * lse_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_fmha_bwd, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +def fmha_bwd_reference( + problem_shape: Tuple[int, int, int, Tuple[Tuple[int, int], int]], + Q: torch.Tensor, # [Q, D, H_R, H_K, B] + K: torch.Tensor, # [K, D, 1, H_K, B] + V: torch.Tensor, # [K, D, 1, H_K, B] + dO: torch.Tensor, # [Q, D, H_R, H_K, B] + O: torch.Tensor, # [Q, D, H_R, H_K, B] + LSE: torch.Tensor, # [Q, H_R, H_K, B] + cumulative_s_q: Union[torch.Tensor, None], + cumulative_s_k: Union[torch.Tensor, None], + is_causal: bool, + bottom_right_align: bool, + window_size_left=None, + window_size_right=None, + upcast=True, + reorder_ops=False, +): + s_q_max, s_k_max, d, hb = problem_shape + h, orig_b = hb + h_r, h_k = h + is_gqa = h_r != 1 + + if upcast: + Q = Q.to(dtype=torch.float32) + K = K.to(dtype=torch.float32) + V = V.to(dtype=torch.float32) + dO = dO.to(dtype=torch.float32) + O = O.to(dtype=torch.float32) + LSE = LSE.to(dtype=torch.float32) + + softmax_scale = 1.0 / math.sqrt(problem_shape[2]) + dV = torch.zeros_like(V) + dK = torch.zeros_like(K) + dQ = torch.zeros_like(Q) + + for b in range(orig_b): + q_offset = cumulative_s_q[b] if cumulative_s_q is not None else 0 + k_offset = cumulative_s_k[b] if cumulative_s_k is not None else 0 + s_q = ( + cumulative_s_q[b + 1] - cumulative_s_q[b] + if cumulative_s_q is not None + else s_q_max + ) + s_k = ( + cumulative_s_k[b + 1] - cumulative_s_k[b] + if cumulative_s_k is not None + else s_k_max + ) + + for h_k_idx in range(h_k): + b_idx = b if cumulative_s_k is None else 0 + cur_K = K[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] + cur_V = V[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] + for h_r_idx in range(h_r): + cur_Q = Q[q_offset : q_offset + s_q, :, h_r_idx, h_k_idx, b_idx] + cur_dO = dO[q_offset : q_offset + s_q, :, h_r_idx, h_k_idx, b_idx] + cur_O = O[q_offset : q_offset + s_q, :, h_r_idx, h_k_idx, b_idx] + cur_LSE = LSE[q_offset : q_offset + s_q, h_r_idx, h_k_idx, b_idx] + cur_LSE = cur_LSE.unsqueeze(-1) + + if not reorder_ops: + cur_S = torch.einsum("qd,kd->qk", cur_Q * softmax_scale, cur_K) + else: + cur_S = torch.einsum("qd,kd->qk", cur_Q, cur_K * softmax_scale) + + if is_causal: + window_size_right = 0 + if window_size_left is not None or window_size_right is not None: + q_coords = torch.arange(0, s_q).cuda().view(-1, 1) + k_coords = torch.arange(0, s_k).cuda().view(1, -1) + offset = 0 if not bottom_right_align else s_k - s_q + if window_size_left is None: + mask = k_coords > q_coords + offset + window_size_right + elif window_size_right is None: + mask = k_coords < q_coords + offset - window_size_left + else: + mask = (k_coords > q_coords + offset + window_size_right) | ( + k_coords < q_coords + offset - window_size_left + ) + cur_S = cur_S.masked_fill(mask, -torch.inf) + + cur_P = torch.exp(cur_S - cur_LSE) + cur_PT = cur_P.transpose(1, 0).to(dtype=Q.dtype) + cur_dV = torch.einsum("kq,qd->kd", [cur_PT, cur_dO]) + + cur_dP = torch.einsum("qd,kd->qk", cur_dO, cur_V) + cur_D = torch.einsum("qd,qd->qd", cur_O, cur_dO) + cur_D = cur_D.sum(dim=1) + cur_D = cur_D.reshape(cur_D.shape[0], 1) + cur_dS = cur_P * (cur_dP - cur_D) * softmax_scale + cur_dS = cur_dS.to(dtype=Q.dtype) + cur_dST = cur_dS.transpose(1, 0) + cur_dK = torch.einsum("kq,qd->kd", cur_dST, cur_Q) + cur_dQ = torch.einsum("qk,kd->qd", cur_dS, cur_K) + + dQ[q_offset : q_offset + s_q, :, h_r_idx, h_k_idx, b_idx] = cur_dQ + if is_gqa: + dV_orig = dV[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] + cur_dV_sum = dV_orig.to(dtype=torch.float32) + cur_dV.to( + dtype=torch.float32 + ) + dV[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] = cur_dV_sum.to( + dtype=V.dtype + ) + dK_orig = dK[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] + dK_cur_sum = dK_orig.to(dtype=torch.float32) + cur_dK.to( + dtype=torch.float32 + ) + dK[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] = dK_cur_sum.to( + dtype=K.dtype + ) + else: + dV[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] = cur_dV + dK[k_offset : k_offset + s_k, :, 0, h_k_idx, b_idx] = cur_dK + + dV = dV.to(dtype=torch.float32) + dK = dK.to(dtype=torch.float32) + dQ = dQ.to(dtype=torch.float32) + + return dV, dK, dQ + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...] | int: + try: + seqlen = tuple(int(x.strip()) for x in s.split(",")) + if len(seqlen) == 1: + return seqlen[0] + return seqlen + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser(description="Example of bwd FMHA on Blackwell.") + + parser.add_argument( + "--element_dtype", + type=cutlass.dtype, + default=Float16, + help="Input data type", + ) + + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=Float32, + help="accumulator data type", + ) + + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="MMA tile shape (M, N)", + ) + + parser.add_argument( + "--is_causal", + action="store_true", + help="Whether to use causal mask", + ) + + parser.add_argument( + "--bottom_right_align", + action="store_true", + help="Whether to use bottom right align, under this settion, the end of q is aligned with the end of k.", + ) + + parser.add_argument( + "--s_q", + type=parse_comma_separated_ints, + default=1024, + help="max sequence length of Q", + ) + + parser.add_argument( + "--s_k", + type=parse_comma_separated_ints, + default=1024, + help="max sequence length of K", + ) + + parser.add_argument( + "--d", + type=int, + default=128, + help="head dimension", + ) + + parser.add_argument( + "--h_q", + type=int, + default=8, + help="number of heads of Q", + ) + + parser.add_argument( + "--h_k", + type=int, + default=8, + help="number of heads of K", + ) + + parser.add_argument( + "--b", + type=int, + default=1, + help="batch size", + ) + + parser.add_argument( + "--scale_softmax", + type=float, + default=0.0, + help="Scaling factor to scale S (i.e. Q*K); if zero, defaults to 1/sqrt(D)", + ) + + parser.add_argument( + "--window_size", + type=parse_comma_separated_ints, + default=(-1, -1), + help="Sliding window size (left, right) for attention masking.", + ) + + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of iterations for warmup", + ) + + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations after warmup", + ) + + parser.add_argument( + "--skip_ref_check", + action="store_true", + help="Skip reference check", + ) + + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if args.mma_tiler_mn != (128, 128): + parser.error("--mma_tiler_mn only supports (128, 128)") + + run( + args.s_q, + args.s_k, + args.h_q, + args.h_k, + args.d, + args.b, + args.is_causal, + args.bottom_right_align, + args.element_dtype, + args.acc_dtype, + args.mma_tiler_mn, + args.scale_softmax, + args.window_size, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py new file mode 100644 index 00000000..e8001903 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py @@ -0,0 +1,3222 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import functools +from typing import List, Type, Tuple, Union +from inspect import isclass + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import from_dlpack + +""" +This example provides an experimental implementation of the SM100 grouped blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. + +A grouped blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of grouped blockscaled GEMM using a TMA plus Blackwell SM100 TensorCore +warp-specialized persistent kernel. +The grouped GEMM workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices +in global memory are passed to the kernel in an array (also held in global memory). Similarly, problem shapes and +strides are also stored in arrays in GMEM. + +This differs from "Batched Array" GEMM since the size of each GEMM problem in the grouped GEMM concept may be distinct. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/grouped_blockscaled_gemm.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(32,384,1536,1),(640,1280,32,1),(640,160,32,1)" \ + --num_groups 4 + +The above example command makes 4 groups of different m, n, k sizes. The Blackwell tcgen05 MMA tile shape +is specified as (128, 64) and the cluster shape is (1,1). The input, mma accumulator and output data type +are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/grouped_blockscaled_gemm.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(32,384,1536,1),(640,1280,32,1),(640,160,32,1)" \ + --num_groups 4 + --warmup_iterations 1 --iterations 10 --skip_ref_check + +Constraints: +* Supported input data types: mxf8, mxf4, nvf4 + see detailed valid dtype combinations in below Sm100GroupedBlockScaledGemmKernel class documentation +* A/B tensors must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 128 or 256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The l mode(aka, batch size) for each group must be 1. +* The majorness for A, B and C must be the same across all groups. +* The contiguous dimension of A/B/C tensors in each group must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. +""" + + +class Sm100GroupedBlockScaledGemmKernel: + """This example demonstrates an implementation of grouped blockscaled GEMM using a TMA plus Blackwell SM100 TensorCore + warp-specialized persistent kernel. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensors must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 128/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell grouped blockscaled GEMM kernel. + + Besides configurations for dense persistent blockscaled GEMM, there is an extra config specific to grouped blockscaled GEMM: + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: tuple[int, int] + :param cluster_shape_mn: tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: tuple[int, int] + """ + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.tensormap_update_mode = utils.TensorMapUpdateMode.SMEM + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + # Barrier used by MMA/TMA warps to signal A/B tensormap initialization completion + self.tensormap_ab_init_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=64, + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + # Set up configurations that dependent on gemm inputs. + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + - Checking reserved smem bytes size capacity for mbar, tensor memory management and tensormap updates utilization + """ + # Compute mma instruction shapes + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cluster_tile_shape_mnk = tuple( + x * y for x, y in zip(self.cta_tile_shape_mnk, (*self.cluster_shape_mn, 1)) + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + mbar_smem_bytes = self._get_mbar_smem_bytes( + num_acc_stage=self.num_acc_stage, + num_ab_stage=self.num_ab_stage, + num_c_stage=self.num_c_stage, + ) + + # Use utils.TensorMapUpdateMode.SMEM by default + tensormap_smem_bytes = ( + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap + * Sm100GroupedBlockScaledGemmKernel.num_tensormaps + ) + if ( + mbar_smem_bytes + + tensormap_smem_bytes + + Sm100GroupedBlockScaledGemmKernel.tensor_memory_management_bytes + > self.reserved_smem_bytes + ): + raise ValueError( + f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the " + f"reserved smem bytes {self.reserved_smem_bytes}" + ) + + @cute.jit + def __call__( + self, + initial_a: cute.Tensor, + initial_b: cute.Tensor, + initial_c: cute.Tensor, + initial_sfa: cute.Tensor, + initial_sfb: cute.Tensor, + group_count: cutlass.Constexpr[int], + problem_shape_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_address_sfasfb: cute.Tensor, + total_num_clusters: cutlass.Constexpr[int], + tensormap_cute_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr[int], + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided + by different tensors in global memory. The "initial" tensors only carry data type and + majorness information. + + :param initial_a: Initial tensor A, used for data type and majorness information. + :type initial_a: cute.Tensor + :param initial_b: Initial tensor B, used for data type and majorness information. + :type initial_b: cute.Tensor + :param initial_c: Initial tensor C, used for data type and majorness information. + :type initial_c: cute.Tensor + :param initial_sfa: Initial tensor SFA, used for data type and majorness information. + :type initial_sfa: cute.Tensor + :param initial_sfb: Initial tensor SFB, used for data type and majorness information. + :type initial_sfb: cute.Tensor + :param group_count: The number of GEMM groups. + :type group_count: cutlass.Constexpr[int] + :param problem_shape_mnkl: Tensor containing the (M, N, K, L) shape for each group. + :type problem_shape_mnkl: cute.Tensor + :param strides_abc: Tensor containing the strides for A, B, and C for each group. + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing the base addresses for A, B, and C for each group. + :type tensor_address_abc: cute.Tensor + :param tensor_address_sfasfb: Tensor containing the base addresses for SFA and SFB for each group. + :type tensor_address_sfasfb: cute.Tensor + :param total_num_clusters: Total number of clusters needed for all groups. + :type total_num_clusters: cutlass.Constexpr[int] + :param tensormap_cute_tensor: Tensor for storing tensormaps. + :type tensormap_cute_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + :param stream: CUDA stream for asynchronous execution. + :type stream: cuda.CUstream + :raises TypeError: If A and B data types do not match. + """ + self.a_dtype = initial_a.element_type + self.b_dtype = initial_b.element_type + self.sf_dtype = initial_sfa.element_type + self.c_dtype = initial_c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(initial_c) + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_a.shape, self.sf_vec_size + ) + initial_sfa = cute.make_tensor(initial_sfa.iterator, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_b.shape, self.sf_vec_size + ) + initial_sfb = cute.make_tensor(initial_sfb.iterator, sfb_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + initial_a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + initial_b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + initial_sfa, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + initial_sfb, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Setup TMA store for C + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + initial_c, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + total_num_clusters, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + self.size_tensormap_in_i64 = ( + Sm100GroupedBlockScaledGemmKernel.num_tensormaps + * Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap + // 8 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, self.size_tensormap_in_i64 + ] + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + group_count, + problem_shape_mnkl, + strides_abc, + tensor_address_abc, + tensor_address_sfasfb, + tensormap_cute_tensor, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + group_count: cutlass.Constexpr, + problem_sizes_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + ptrs_abc: cute.Tensor, + ptrs_sfasfb: cute.Tensor, + tensormaps: cute.Tensor, + ): + """ + GPU device kernel performing the grouped GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + if warp_idx == self.tma_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_sfa) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_sfb) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: tensormap buffer, a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_sfa_smem_ptr = ( + tensormap_b_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_sfb_smem_ptr = ( + tensormap_sfa_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_c_smem_ptr = ( + tensormap_sfb_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/SFA/SFB/C + # + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA Load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA Load SFB partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + # + # Get tensormap buffer address + # + grid_dim = cute.arch.grid_dim() + tensormap_workspace_idx = ( + bidz * grid_dim[1] * grid_dim[0] + bidy * grid_dim[0] + bidx + ) + + tensormap_manager = utils.TensorMapManager( + utils.TensorMapUpdateMode.SMEM, + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap, + ) + tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 0, None)].iterator + ) + tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 1, None)].iterator + ) + tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 2, None)].iterator + ) + tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 3, None)].iterator + ) + tensormap_c_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 4, None)].iterator + ) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + tensormap_init_done = cutlass.Boolean(False) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + # skip tensormap update if we're working on the same group + if is_group_changed: + real_tensor_a = self.make_tensor_abc_for_tensormap_update( + cur_group_idx, + self.a_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 0, # 0 for tensor A + ) + real_tensor_b = self.make_tensor_abc_for_tensormap_update( + cur_group_idx, + self.b_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 1, # 1 for tensor B + ) + real_tensor_sfa = self.make_tensor_sfasfb_for_tensormap_update( + cur_group_idx, + self.sf_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + ptrs_sfasfb, + 0, # 0 for tensor SFA + ) + real_tensor_sfb = self.make_tensor_sfasfb_for_tensormap_update( + cur_group_idx, + self.sf_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + ptrs_sfasfb, + 1, # 1 for tensor SFB + ) + if tensormap_init_done == False: + # wait tensormap initialization complete + self.tensormap_ab_init_barrier.arrive_and_wait() + tensormap_init_done = True + + tensormap_manager.update_tensormap( + ( + real_tensor_a, + real_tensor_b, + real_tensor_sfa, + real_tensor_sfb, + ), + (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb), + ( + tensormap_a_gmem_ptr, + tensormap_b_gmem_ptr, + tensormap_sfa_gmem_ptr, + tensormap_sfb_gmem_ptr, + ), + self.tma_warp_id, + ( + tensormap_a_smem_ptr, + tensormap_b_smem_ptr, + tensormap_sfa_smem_ptr, + tensormap_sfb_smem_ptr, + ), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < cur_k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + if is_group_changed: + tensormap_manager.fence_tensormap_update(tensormap_a_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfa_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfb_gmem_ptr) + # + # Tma load loop + # + for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfa_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfb_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < cur_k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Initialize tensormaps for A, B, SFA and SFB + # + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_smem_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_smem_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfa, tensormap_sfa_smem_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfb, tensormap_sfb_smem_ptr, self.mma_warp_id + ) + # indicate tensormap initialization has finished + self.tensormap_ab_init_barrier.arrive_and_wait() + + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + self.tmem_alloc_barrier.arrive_and_wait() + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + # Make accumulator tmem tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + # MMA warp is only interested in number of tiles along K dimension + ( + cur_k_tile_cnt, + cur_group_idx, + ) = group_gemm_ts_helper.search_cluster_tile_count_k( + cur_tile_coord, + problem_sizes_mnkl, + ) + + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < cur_k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(cur_k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < cur_k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # initialize tensorap for C + tensormap_manager.init_tensormap_from_atom( + tma_atom_c, + tensormap_c_smem_ptr, + self.epilog_warp_id[0], + ) + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + self.tmem_alloc_barrier.arrive_and_wait() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + ### Start from here + # + # Partition for epilogue + # + epi_tidx = tidx + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + ) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC_partitioned = ( + self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + # group index to start searching + last_group_idx = cutlass.Int32(-1) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + + if is_group_changed: + # construct tensor c based on real shape, stride information + real_tensor_c = self.make_tensor_abc_for_tensormap_update( + cur_group_idx, + self.c_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 2, # 2 for tensor C + ) + tensormap_manager.update_tensormap( + ((real_tensor_c),), + ((tma_atom_c),), + ((tensormap_c_gmem_ptr),), + self.epilog_warp_id[0], + (tensormap_c_smem_ptr,), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + if is_group_changed: + if warp_idx == self.epilog_warp_id[0]: + tensormap_manager.fence_tensormap_update(tensormap_c_gmem_ptr) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + tRS_rC.store(acc_vec.to(self.c_dtype)) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_c_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + self.epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + @cute.jit + def make_tensor_abc_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_index: int, + ): + """Extract stride and tensor address for a given group and construct a global tensor for A, B or C. + + This function is used within the kernel to dynamically create a CUTE tensor + representing A, B, or C for the current group being processed, using the + group-specific address, shape, and stride information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param strides_abc: Tensor containing strides for A, B, C for all groups. Layout: (group_count, 3, 2). + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing global memory addresses for A, B, C for all groups. Layout: (group_count, 3). + :type tensor_address_abc: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for A, 1 for B, 2 for C. + :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (A, B, or C) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. + """ + ptr_i64 = tensor_address_abc[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] + strides_tensor_reg = cute.make_rmem_tensor( + cute.make_layout(2), + strides_abc.element_type, + ) + cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg) + stride_mn = strides_tensor_reg[0] + stride_k = strides_tensor_reg[1] + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(tensor_index == 0): # tensor A + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)), + ) + elif cutlass.const_expr(tensor_index == 1): # tensor B + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)), + ) + else: # tensor C + m = problem_shape_mnk[0] + n = problem_shape_mnk[1] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)), + ) + + @cute.jit + def make_tensor_sfasfb_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + tensor_address_sfasfb: cute.Tensor, + tensor_index: int, + ): + """Extract tensor address for a given group and construct a global tensor for SFA or SFB. + + This function is used within the kernel to dynamically create a CUTE tensor + representing SFA or SFB for the current group being processed, using the + group-specific address, shape information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param tensor_address_sfasfb: Tensor containing global memory addresses for SFA, SFB for all groups. Layout: (group_count, 2). + :type tensor_address_sfasfb: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for SFA, 1 for SFB. + :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (SFA, SFB) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. + """ + ptr_i64 = tensor_address_sfasfb[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + c1 = cutlass.Int32(1) + if cutlass.const_expr(tensor_index == 0): # tensor SFA + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + (m, k, c1), self.sf_vec_size + ) + return cute.make_tensor( + tensor_gmem_ptr, + sfa_layout, + ) + else: # tensor SFB + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + (n, k, c1), self.sf_vec_size + ) + return cute.make_tensor( + tensor_gmem_ptr, + sfb_layout, + ) + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of Scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + total_num_clusters: int, + cluster_shape_mn: tuple[int, int], + max_active_clusters: cutlass.Constexpr[int], + ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: + """Compute tile scheduler parameters and grid shape for grouped GEMM operations. + + :param total_num_clusters: Total number of clusters to process across all groups. + :type total_num_clusters: int + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, ...]] + """ + # Create problem shape with M, N dimensions from cluster shape + # and L dimension representing the total number of clusters. + problem_shape_ntile_mnl = ( + cluster_shape_mn[0], + cluster_shape_mn[1], + cutlass.Int32(total_num_clusters), + ) + + tile_sched_params = utils.PersistentTileSchedulerParams( + problem_shape_ntile_mnl, (*cluster_shape_mn, 1) + ) + + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_mbar_smem_bytes(**kwargs_stages: int) -> int: + """Calculate shared memory consumption for memory barriers based on provided stages. + + Each stage requires 2 barriers, and each barrier consumes 8 bytes of shared memory. + The total consumption is the sum across all provided stages. This function calculates the total + shared memory needed for these barriers. + + :param kwargs_stages: Variable keyword arguments where each key is a stage name + (e.g., num_acc_stage, num_ab_stage) and each value is the + number of stages of that type. + :type kwargs_stages: int + :return: Total shared memory bytes required for all memory barriers. + :rtype: int + """ + num_barriers_per_stage = 2 + num_bytes_per_barrier = 8 + mbar_smem_consumption = sum( + [ + num_barriers_per_stage * num_bytes_per_barrier * stage + for stage in kwargs_stages.values() + ] + ) + return mbar_smem_consumption + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if layouts and dtypes are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if mma_tiler_mn[0] not in [128, 256]: + is_valid = False + if mma_tiler_mn[1] not in [128, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + problem_sizes_mnkl: List[Tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param problem_sizes_mnkl: The problem shape for each group + :type problem_sizes_mnkl: List[Tuple[int, int, int, int]] + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + for m, n, k, l in problem_sizes_mnkl: + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment( + ab_dtype, b_major == "n", (n, k, l) + ) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + problem_sizes_mnkl: List[Tuple[int, int, int, int]], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm100GroupedBlockScaledGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not Sm100GroupedBlockScaledGemmKernel.is_valid_layouts( + ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100GroupedBlockScaledGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100GroupedBlockScaledGemmKernel.is_valid_tensor_alignment( + problem_sizes_mnkl, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + # Size of smem we reserved for mbarrier, tensor memory management and tensormap update + reserved_smem_bytes = 1024 + bytes_per_tensormap = 128 + num_tensormaps = 5 + # size of smem used for tensor memory management + tensor_memory_management_bytes = 12 + + +# Create tensor and return the pointer, tensor, and stride +def create_tensor_and_stride( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: type[cutlass.Numeric], + is_dynamic_layout: bool = True, +) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: + """Create GPU tensor from either a new or existing CPU tensor. + + :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. + :type torch_tensor_cpu: torch.Tensor, optional + """ + + # Create new CPU tensor + torch_tensor_cpu = cutlass_torch.matrix( + l, + mode0, + mode1, + is_mode0_major, + cutlass.Float32, + ) + + # Create GPU tensor from CPU tensor (new or existing) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 + ) + + # Mark tensor with element divisibility for 16B alignment + cute_tensor.mark_compact_shape_dynamic( + mode=0 if is_mode0_major else 1, + stride_order=(2, 1, 0) if is_mode0_major else (2, 0, 1), + divisibility=32 if dtype == cutlass.Float4E2M1FN else 16, + ) + + # omit stride for L mode as it is always 1 + stride = (1, mode0) if is_mode0_major else (mode1, 1) + + return ( + torch_tensor.data_ptr(), + torch_tensor, + cute_tensor, + torch_tensor_cpu, + stride, + ) + + +def create_tensors_abc_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[tuple]], + List[List[torch.Tensor]], +]: + ref_torch_fp32_tensors_abc = [] + torch_tensors_abc = [] + cute_tensors_abc = [] + strides_abc = [] + ptrs_abc = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + # Create tensors A, B, C + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + ref_torch_fp32_tensor_a, + stride_mk_a, + ) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype) + + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + ref_torch_fp32_tensor_b, + stride_nk_b, + ) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype) + + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + ref_torch_fp32_tensor_c, + stride_mn_c, + ) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype) + + ref_torch_fp32_tensors_abc.append( + [ref_torch_fp32_tensor_a, ref_torch_fp32_tensor_b, ref_torch_fp32_tensor_c] + ) + + ptrs_abc.append([ptr_a, ptr_b, ptr_c]) + torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + cute_tensors_abc.append( + ( + cute_tensor_a, + cute_tensor_b, + cute_tensor_c, + ) + ) + + return ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + ref_torch_fp32_tensors_abc, + ) + + +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +# Create scale factor tensor SFA/SFB +def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): + def ceil_div(a, b): + return (a + b - 1) // b + + sf_k = ceil_div(k, sf_vec_size) + ref_shape = (l, mn, sf_k) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + ref_permute_order = (1, 2, 0) + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # Create f32 ref torch tensor (cpu) + ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + ref_shape, + torch.float32, + permute_order=ref_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=1, + max_val=3, + ), + ) + + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + mma_shape, + torch.float32, + permute_order=mma_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0, + max_val=1, + ), + ) + + # convert ref f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_f32_torch_tensor_cpu), + from_dlpack(cute_f32_torch_tensor_cpu), + ) + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() + + # reshape makes memory contiguous + ref_f32_torch_tensor_cpu = ( + ref_f32_torch_tensor_cpu.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. + ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] + + # Create dtype cute torch tensor (cpu) + cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( + cute_f32_torch_tensor_cpu, + dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # Convert f32 cute tensor to dtype cute tensor + cute_tensor = cutlass_torch.convert_cute_tensor( + cute_f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=True, + ) + # get pointer of the tensor + ptr = cute_torch_tensor.data_ptr() + return ref_f32_torch_tensor_cpu, ptr, cute_tensor, cute_torch_tensor + + +def create_tensors_sfasfb_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[torch.Tensor]], +]: + ptrs_sfasfb = [] + torch_tensors_sfasfb = [] + cute_tensors_sfasfb = [] + refs_sfasfb = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + sfa_ref, ptr_sfa, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype + ) + sfb_ref, ptr_sfb, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype + ) + ptrs_sfasfb.append([ptr_sfa, ptr_sfb]) + torch_tensors_sfasfb.append([sfa_torch, sfb_torch]) + cute_tensors_sfasfb.append( + ( + sfa_tensor, + sfb_tensor, + ) + ) + refs_sfasfb.append([sfa_ref, sfb_ref]) + + return ( + ptrs_sfasfb, + torch_tensors_sfasfb, + cute_tensors_sfasfb, + refs_sfasfb, + ) + + +def run( + num_groups: int, + problem_sizes_mnkl: List[Tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Run SM100 grouped blockscaledGEMM example with specified configurations. + + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print("Running Blackwell Grouped GEMM test with:") + print(f"{num_groups} groups") + for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): + print(f"Group {i}: {m}x{n}x{k}x{l}") + print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") + print(f"C dtype: {c_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Skip unsupported testcase + if not Sm100GroupedBlockScaledGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + problem_sizes_mnkl, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {problem_sizes_mnkl}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(2025) + + # Create tensors A, B, C for all groups + ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + ref_f32_torch_tensors_abc, + ) = create_tensors_abc_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + # Create tensors SFA, SFB for all groups + ( + ptrs_sfasfb, + torch_tensors_sfasfb, + cute_tensors_sfasfb, + refs_f32_torch_tensors_sfasfb, + ) = create_tensors_sfasfb_for_all_groups( + problem_sizes_mnkl, + sf_dtype, + sf_vec_size, + ) + + # Choose A, B, C, SFA, SFB with the smallest size to create initial tensormaps + key_size_a = lambda item: item[1][0] * item[1][2] + key_size_b = lambda item: item[1][1] * item[1][2] + key_size_c = lambda item: item[1][0] * item[1][1] + # Find the indices of the groups with the smallest tensor sizes + min_a_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_a) + min_b_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_b) + min_c_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_c) + initial_cute_tensors_abc = [ + cute_tensors_abc[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc[min_c_idx][2], # C with smallest (m, n) + ] + initial_cute_tensors_sfasfb = [ + cute_tensors_sfasfb[min_a_idx][0], # SFA with smallest (m, k)'s group + cute_tensors_sfasfb[min_b_idx][1], # SFB with smallest (n, k)'s group + ] + + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_max_active_clusters(1) + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + # Prepare tensormap buffer for each SM + num_tensormap_buffers = sm_count + tensormap_shape = ( + num_tensormap_buffers, + Sm100GroupedBlockScaledGemmKernel.num_tensormaps, + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8, + ) + tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + grouped_blockscaled_gemm = Sm100GroupedBlockScaledGemmKernel( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ) + + # layout (num_groups, 4):(4, 1) + ( + tensor_of_dim_size_mnkl, + tensor_of_dim_size_mnkl_torch, + ) = cutlass_torch.cute_tensor_like( + torch.tensor(problem_sizes_mnkl, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups, 3, 2):(6, 2, 1) + tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,3):(3, 1) + tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,2):(2, 1) + tensor_of_ptrs_sfasfb, tensor_of_ptrs_sfasfb_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_sfasfb, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + def compute_total_num_clusters( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + cluster_tile_shape_mn: tuple[int, int], + ) -> int: + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + ) -> tuple[int, int]: + cta_tile_shape_mn = [128, mma_tiler_mn[1]] + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape(mma_tiler_mn, cluster_shape_mn) + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cluster_tile_shape_mn + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile grouped GEMM kernel + compiled_grouped_gemm = cute.compile( + grouped_blockscaled_gemm, + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + initial_cute_tensors_sfasfb[0], + initial_cute_tensors_sfasfb[1], + num_groups, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_ptrs_sfasfb, + total_num_clusters, + tensor_of_tensormap, + max_active_clusters, + current_stream, + ) + + # reference check + if not skip_ref_check: + compiled_grouped_gemm( + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + initial_cute_tensors_sfasfb[0], + initial_cute_tensors_sfasfb[1], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_ptrs_sfasfb, + tensor_of_tensormap, + current_stream, + ) + print("Verifying results...") + + for i, ( + (a_ref, b_ref, c_ref), + (sfa_ref, sfb_ref), + (a_tensor, b_tensor, c_tensor), + (m, n, k, l), + ) in enumerate( + zip( + ref_f32_torch_tensors_abc, + refs_f32_torch_tensors_sfasfb, + cute_tensors_abc, + problem_sizes_mnkl, + ) + ): + ref_res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + ref_res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = torch.einsum("mkl,nkl->mnl", ref_res_a, ref_res_b) + + print(f"checking group {i}") + c_ref_device = c_ref.cuda() + + cute.testing.convert( + c_tensor, + from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=(1 if c_major == "n" else 0) + ), + ) + + c_ref = c_ref_device.cpu() + + if c_dtype in (cutlass.Float32, cutlass.Float16, cutlass.BFloat16): + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + elif c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): + # Convert ref : f32 -> f8 -> f32 + ref_f8_ = torch.empty( + *(l, m, n), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + ref_f8.element_type = c_dtype + ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda() + ref_tensor = from_dlpack( + ref_device, assumed_align=16 + ).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(ref_tensor, ref_f8) + cute.testing.convert(ref_f8, ref_tensor) + ref = ref_device.cpu() + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + def generate_tensors(): + ( + ptrs_abc_workspace, + torch_tensors_abc_workspace, + cute_tensors_abc_workspace, + strides_abc_workspace, + _, + ) = create_tensors_abc_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + + ( + ptrs_sfasfb_workspace, + torch_tensors_sfasfb_workspace, + cute_tensors_sfasfb_workspace, + _, + ) = create_tensors_sfasfb_for_all_groups( + problem_sizes_mnkl, + sf_dtype, + sf_vec_size, + ) + + initial_cute_tensors_abc_workspace = [ + cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n) + ] + + initial_cute_tensors_sfasfb_workspace = [ + cute_tensors_sfasfb_workspace[min_a_idx][ + 0 + ], # SFA with smallest (m, k)'s group + cute_tensors_sfasfb_workspace[min_b_idx][ + 1 + ], # SFB with smallest (n, k)'s group + ] + + # Create new tensors for this workspace + tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc_workspace, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_sfasfb_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_sfasfb_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensormap_workspace, _ = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + return cute.testing.JitArguments( + initial_cute_tensors_abc_workspace[0], + initial_cute_tensors_abc_workspace[1], + initial_cute_tensors_abc_workspace[2], + initial_cute_tensors_sfasfb_workspace[0], + initial_cute_tensors_sfasfb_workspace[1], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc_workspace, + tensor_of_ptrs_abc_workspace, + tensor_of_ptrs_sfasfb_workspace, + tensormap_workspace, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + sum( + [ + sum( + [ + torch_tensor.numel() * torch_tensor.element_size() + for torch_tensor in group_tensors + ] + ) + for group_tensors in torch_tensors_abc + torch_tensors_sfasfb + ] + ) + + + # Add size of strides tensor + tensor_of_strides_abc_torch.numel() + * tensor_of_strides_abc_torch.element_size() + + + # Add size of ptrs tensor A, B, C + tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + + + # Add size of ptrs tensor SFA, SFB + tensor_of_ptrs_sfasfb_torch.numel() + * tensor_of_ptrs_sfasfb_torch.element_size() + + + # Add size of tensormap tensor + tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() + ) + workspace_count = cute.testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = cute.testing.benchmark( + compiled_grouped_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_comma_separated_tuples(s: str) -> List[tuple[int, ...]]: + if s.strip().startswith("("): + # Split on ),( to separate tuples + tuples = s.strip("()").split("),(") + result = [] + tuple_len = None + + for t in tuples: + # Parse individual tuple + nums = [int(x.strip()) for x in t.split(",")] + + # Validate tuple length consistency + if tuple_len is None: + tuple_len = len(nums) + elif len(nums) != tuple_len: + raise argparse.ArgumentTypeError( + "All tuples must have the same length" + ) + + result.append(tuple(nums)) + return result + + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers or list of tuples" + ) + + parser = argparse.ArgumentParser( + description="Example of Grouped GEMM on Blackwell." + ) + parser.add_argument( + "--num_groups", + type=int, + default=2, + help="Number of groups", + ) + parser.add_argument( + "--problem_sizes_mnkl", + type=parse_comma_separated_tuples, + default=((128, 128, 128, 1), (128, 128, 128, 1)), + help="a tuple of problem sizes for each group (comma-separated tuples)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN) + parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E8M0FNU) + parser.add_argument("--sf_vec_size", type=int, default=16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if ( + len(args.problem_sizes_mnkl) != 0 + and len(args.problem_sizes_mnkl) != args.num_groups + ): + parser.error("--problem_sizes_mnkl must contain exactly num_groups tuples") + + # l mode must be 1 for all groups + for _, _, _, l in args.problem_sizes_mnkl: + if l != 1: + parser.error("l must be 1 for all groups") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.num_groups, + args.problem_sizes_mnkl, + args.ab_dtype, + args.sf_dtype, + args.sf_vec_size, + args.c_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_gemm.py index c279f758..50d455da 100644 --- a/examples/python/CuTeDSL/blackwell/grouped_gemm.py +++ b/examples/python/CuTeDSL/blackwell/grouped_gemm.py @@ -38,6 +38,7 @@ import cutlass import cutlass.cute as cute import cutlass.cute.testing as testing import cutlass.utils as utils +import cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.torch as cutlass_torch @@ -152,12 +153,24 @@ class GroupedGemmKernel: self.threads_per_cta = 32 * len( (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) ) - # Set barrier id for cta sync, epilog sync, tmem ptr sync and tensormap update sync - self.cta_sync_bar_id = 0 - self.epilog_sync_bar_id = 1 - self.tmem_ptr_sync_bar_id = 2 - # Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion - self.tensormap_ab_init_bar_id = 4 + # Set barrier for cta sync, epilog sync, tmem ptr sync and tensormap update sync + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + # Barrier used by MMA/TMA warps to signal A/B tensormap initialization completion + self.tensormap_ab_init_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=32 * (len(self.epilog_warp_id) + 1), + ) self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") self.num_tma_load_bytes = 0 @@ -251,14 +264,6 @@ class GroupedGemmKernel: self.num_epi_stage, ) - tensor_smem_bytes = self._get_tensor_smem_bytes( - self.a_smem_layout_staged, - self.a_dtype, - self.b_smem_layout_staged, - self.b_dtype, - self.epi_smem_layout_staged, - self.c_dtype, - ) mbar_smem_bytes = self._get_mbar_smem_bytes( num_acc_stage=self.num_acc_stage, num_ab_stage=self.num_ab_stage, @@ -390,15 +395,12 @@ class GroupedGemmKernel: # Setup TMA store for C tma_atom_c = None tma_tensor_c = None - c_cta_v_layout = cute.composition( - cute.make_identity_layout(initial_c.shape), self.epi_tile - ) epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0)) tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), initial_c, epi_smem_layout, - c_cta_v_layout, + self.epi_tile, ) self.tile_sched_params, grid = self._compute_grid( @@ -558,8 +560,6 @@ class GroupedGemmKernel: ab_empty_mbar_ptr = storage.ab_empty_mbar_ptr.data_ptr() acc_full_mbar_ptr = storage.acc_full_mbar_ptr.data_ptr() acc_empty_mbar_ptr = storage.acc_empty_mbar_ptr.data_ptr() - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf # init barrier for loading A, B with TMA if warp_idx == self.epilog_warp_id[0]: @@ -579,13 +579,13 @@ class GroupedGemmKernel: acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4 ) # Tensor memory dealloc barrier init - if use_2cta_instrs: - if warp_idx == self.tma_warp_id: - num_tmem_dealloc_threads = 32 - with cute.arch.elect_one(): - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads - ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) cute.arch.mbarrier_init_fence() # Cluster arrive after barrier init @@ -721,9 +721,7 @@ class GroupedGemmKernel: if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() else: - cute.arch.barrier( - barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta - ) + self.cta_sync_barrier.arrive_and_wait() # # Get tensormap buffer address @@ -826,10 +824,7 @@ class GroupedGemmKernel: # wait tensormap initialization complete before update if tensormap_init_done == False: if cutlass.const_expr(self.delegate_tensormap_ab_init): - cute.arch.barrier( - barrier_id=self.tensormap_ab_init_bar_id, - number_of_threads=64, - ) + self.tensormap_ab_init_barrier.arrive_and_wait() tensormap_manager.fence_tensormap_initialization() tensormap_init_done = True @@ -951,33 +946,13 @@ class GroupedGemmKernel: # Specialized MMA warp # if warp_idx == self.mma_warp_id: - # initialize tensormap A, B for TMA warp - if cutlass.const_expr(self.delegate_tensormap_ab_init): - tensormap_manager.init_tensormap_from_atom( - tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id - ) - tensormap_manager.init_tensormap_from_atom( - tma_atom_b, tensormap_b_init_ptr, self.mma_warp_id - ) - # signal tensormap initialization has finished - cute.arch.barrier( - barrier_id=self.tensormap_ab_init_bar_id, number_of_threads=64 - ) # Bar sync for retrieve tmem ptr from shared mem - tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) - cute.arch.barrier( - barrier_id=self.tmem_ptr_sync_bar_id, - number_of_threads=tmem_ptr_read_threads, - ) + tmem.wait_for_alloc() # # Retrieving tensor memory ptr and make accumulator tensor # - tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, - alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, - ) + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) @@ -1019,22 +994,22 @@ class GroupedGemmKernel: # Peek (try_wait) AB buffer full for k_tile = 0 mma_rd_k_tile = cutlass.Int32(0) smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage - need_check_rd_buffer_full = ( - mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta - ) - mma_rd_ab_full_phase = ( - (num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2 - ) - peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( - need_check_rd_buffer_full, - ab_full_mbar_ptr + smem_rd_buffer, - mma_rd_ab_full_phase, - ) - - # - # Wait for accumulator buffer empty - # if is_leader_cta: + need_check_rd_buffer_full = ( + mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta + ) + mma_rd_ab_full_phase = ( + (num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2 + ) + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer, + mma_rd_ab_full_phase, + ) + + # + # Wait for accumulator buffer empty + # acc_empty_phase = ( tile_sched.num_tiles_executed // self.num_acc_stage % 2 ^ 1 ) @@ -1042,25 +1017,24 @@ class GroupedGemmKernel: acc_empty_mbar_ptr + acc_buf_idx, acc_empty_phase ) - # - # Reset the ACCUMULATE field for each tile - # - tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - # - # Mma mainloop - # - for k_tile in range(cur_k_tile_cnt): - mma_rd_k_tile_next = cutlass.Int32(k_tile + 1) - smem_rd_buffer_next = ( - num_prev_k_blk + mma_rd_k_tile_next - ) % self.num_ab_stage - mma_rd_ab_full_phase_next = ( - mma_rd_ab_full_phase ^ 1 - if smem_rd_buffer_next == 0 - else mma_rd_ab_full_phase - ) - if is_leader_cta: + # + # Mma mainloop + # + for k_tile in range(cur_k_tile_cnt): + mma_rd_k_tile_next = cutlass.Int32(k_tile + 1) + smem_rd_buffer_next = ( + num_prev_k_blk + mma_rd_k_tile_next + ) % self.num_ab_stage + mma_rd_ab_full_phase_next = ( + mma_rd_ab_full_phase ^ 1 + if smem_rd_buffer_next == 0 + else mma_rd_ab_full_phase + ) # Wait for AB buffer full if peek_ab_full_status == 0: cute.arch.mbarrier_wait( @@ -1090,25 +1064,24 @@ class GroupedGemmKernel: self.cta_group, ) - # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 - need_check_rd_buffer_full = ( - mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta - ) + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + need_check_rd_buffer_full = ( + mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta + ) - peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( - need_check_rd_buffer_full, - ab_full_mbar_ptr + smem_rd_buffer_next, - mma_rd_ab_full_phase_next, - ) + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer_next, + mma_rd_ab_full_phase_next, + ) - mma_rd_k_tile = mma_rd_k_tile_next - smem_rd_buffer = smem_rd_buffer_next - mma_rd_ab_full_phase = mma_rd_ab_full_phase_next + mma_rd_k_tile = mma_rd_k_tile_next + smem_rd_buffer = smem_rd_buffer_next + mma_rd_ab_full_phase = mma_rd_ab_full_phase_next - # - # Async arrive accumulator buffer full - # - if is_leader_cta: + # + # Async arrive accumulator buffer full + # with cute.arch.elect_one(): tcgen05.commit( acc_full_mbar_ptr + acc_buf_idx, @@ -1126,6 +1099,16 @@ class GroupedGemmKernel: # Specialized epilogue warps # if warp_idx < self.mma_warp_id: + # initialize tensormap A, B for TMA warp + if cutlass.const_expr(self.delegate_tensormap_ab_init): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.epilog_warp_id[0] + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.epilog_warp_id[0] + ) + # signal tensormap initialization has finished + self.tensormap_ab_init_barrier.arrive_and_wait() # initialize tensorap for C tensormap_manager.init_tensormap_from_atom( tma_atom_c, @@ -1133,30 +1116,17 @@ class GroupedGemmKernel: self.epilog_warp_id[0], ) # Alloc tensor memory buffer - if warp_idx == self.epilog_warp_id[0]: - cute.arch.alloc_tmem( - self.num_tmem_alloc_cols, - tmem_holding_buf, - is_two_cta=use_2cta_instrs, - ) + tmem.allocate(self.num_tmem_alloc_cols) # # Bar sync for retrieve tensor memory ptr from shared memory # - tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) - cute.arch.barrier( - barrier_id=self.tmem_ptr_sync_bar_id, - number_of_threads=tmem_ptr_read_threads, - ) + tmem.wait_for_alloc() # # Retrieving tensor memory ptr and make accumulator tensor # - tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, - alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, - ) + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) @@ -1172,7 +1142,7 @@ class GroupedGemmKernel: epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs ) - tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( tiled_copy_t2r, tTR_rC, epi_tidx, sC ) @@ -1303,11 +1273,7 @@ class GroupedGemmKernel: cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, ) - epilog_threads = 32 * len(self.epilog_warp_id) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=epilog_threads, - ) + self.epilog_sync_barrier.arrive_and_wait() # # store C to global memory with TMA # @@ -1325,10 +1291,7 @@ class GroupedGemmKernel: cute.arch.cp_async_bulk_wait_group( self.num_epi_stage - 1, read=True ) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=epilog_threads, - ) + self.epilog_sync_barrier.arrive_and_wait() # # Async arrive accumulator buffer empty # @@ -1348,21 +1311,9 @@ class GroupedGemmKernel: # # Dealloc the tensor memory buffer # - if warp_idx == self.epilog_warp_id[0]: - cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) - epilog_threads = 32 * len(self.epilog_warp_id) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads - ) - if warp_idx == self.epilog_warp_id[0]: - if use_2cta_instrs: - cute.arch.mbarrier_arrive( - tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 - ) - cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - cute.arch.dealloc_tmem( - tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs - ) + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) # # Wait a/b buffer empty @@ -1417,7 +1368,7 @@ class GroupedGemmKernel: ) strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] - strides_tensor_reg = cute.make_fragment( + strides_tensor_reg = cute.make_rmem_tensor( cute.make_layout(2), strides_abc.element_type, ) @@ -1507,7 +1458,7 @@ class GroupedGemmKernel: # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_fragment( + tTR_rAcc = cute.make_rmem_tensor( tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc @@ -1761,23 +1712,6 @@ class GroupedGemmKernel: else: raise ValueError(f"Invalid tensormap update mode: {tensormap_update_mode}") - @staticmethod - def _get_tensor_smem_bytes( - a_smem_layout_staged: cute.Layout, - a_dtype: Type[cutlass.Numeric], - b_smem_layout_staged: cute.Layout, - b_dtype: Type[cutlass.Numeric], - epi_smem_layout_staged: cute.Layout, - c_dtype: Type[cutlass.Numeric], - ) -> int: - """Compute the total SMEM consumption for tensor A, B and C.""" - ab_bytes = cute.size_in_bytes( - a_dtype, a_smem_layout_staged - ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged) - - epi_bytes = cute.size_in_bytes(c_dtype, epi_smem_layout_staged) - return ab_bytes + epi_bytes - @staticmethod def _compute_num_tmem_alloc_cols( tiled_mma: cute.TiledMma, @@ -1821,7 +1755,7 @@ def create_tensor_and_stride( is_dynamic_layout: bool = True, torch_tensor_cpu: torch.Tensor = None, ) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: - """Create a GPU tensor from scratch or based on an existing CPU tensor. + """Create GPU tensor from either a new or existing CPU tensor. :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. :type torch_tensor_cpu: torch.Tensor, optional @@ -1967,7 +1901,7 @@ def run( :return: Execution time of the GEMM kernel in microseconds :rtype: float """ - print(f"Running Blackwell Grouped GEMM test with:") + print("Running Blackwell Grouped GEMM test with:") print(f"{num_groups} groups") for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): print(f"Group {i}: {m}x{n}x{k}x{l}") @@ -2102,6 +2036,7 @@ def run( is_dynamic_layout=False, assumed_align=16, ) + # layout (num_groups, 3, 2):(6, 2, 1) tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( torch.tensor(strides_abc, dtype=torch.int32), diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py index 829d6b7e..afa8830b 100644 --- a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py @@ -86,9 +86,9 @@ class SSDKernel: cutlass.BFloat16, }, "Do not support other I/O types." assert acc_dtype in {cutlass.Float32}, "Do not support other ACC types." - assert cumsum_delta_dtype in { - cutlass.Float32 - }, "Do not support other cumsum types." + assert cumsum_delta_dtype in {cutlass.Float32}, ( + "Do not support other cumsum types." + ) assert not (not has_d and d_has_hdim), "D cannot have Hdim if has_d is False" # Hardcode default setting @@ -129,10 +129,18 @@ class SSDKernel: self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") # Named barriers - self.pre_inter_sync_bar_id = 1 - self.epilog_sync_bar_id = 2 - self.pre_intra_sync_bar_id = 3 - self.tmem_dealloc_sync_bar_id = 4 + self.pre_inter_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=len(self.pre_inter_warp_id) * 32, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=len(self.epilog_warp_id) * 32, + ) + self.tmem_dealloc_sync_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=self.threads_per_cta, + ) # Number of registers used by each warp self.num_regs_uniform_warps = 24 @@ -467,15 +475,12 @@ class SSDKernel: ) # TMA store for y - y_cta_v_layout = cute.composition( - cute.make_identity_layout(y.shape), self.epi_tile - ) y_smem_layout = cute.slice_(self.y_smem_layout, (None, None, 0)) tma_atom_y, tma_tensor_y = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), y, y_smem_layout, - y_cta_v_layout, + self.epi_tile, ) # TMA store for fstate(p) @@ -512,7 +517,9 @@ class SSDKernel: d_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore # Intra1 acc stage barriers intra1_acc_full: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore - intra1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore + intra1_acc_empty: cute.struct.MemRange[ + cutlass.Int64, self.intra1_acc_stages + ] # type: ignore # Internal stage barriers intra2_q_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore intra2_q_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore @@ -811,23 +818,22 @@ class SSDKernel: if cute.size(self.cluster_shape_mnk) > 1: cute.arch.cluster_wait() - # Alloc tmem buffer - if warp_idx == self.epilog_warp_id[0]: - cute.arch.alloc_tmem( - self.num_tmem_cols_total, - smem_storage.tmem_holding_buf, - is_two_cta=self.use_2cta_instrs, - ) + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=0, + num_threads=self.threads_per_cta, + ) + tmem = utils.TmemAllocator( + smem_storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + ) + tmem.allocate(self.num_tmem_cols_total) - # Bar sync before retrieving tmem ptr from shared mem - cute.arch.barrier() + # Barrier before retrieve tensor memory ptr from shared memory + tmem.wait_for_alloc() # Retrieve tmem ptr - tmem_ptr_base = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, - alignment=16, - ptr_to_buffer_holding_addr=smem_storage.tmem_holding_buf, - ) + tmem_ptr_base = tmem.retrieve_ptr(self.acc_dtype) # Specialized TMA load Delta/CumsumDelta/X warp if warp_idx == self.tma_deltas_x_d_warp_id: @@ -1579,7 +1585,7 @@ class SSDKernel: ) = self.pre_inter_tmem_load_and_partition_p(local_tidx, tInter1, smem_pt) # Make fragment for register to hold P after post-processing (in acc dtype) - tState = cute.make_fragment(tTR_rP.shape, self.acc_dtype) + tState = cute.make_rmem_tensor(tTR_rP.shape, self.acc_dtype) # Make tiledCopy and partition smem/register tensor for smem store INTER2_P # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) @@ -1621,7 +1627,7 @@ class SSDKernel: tma_p_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.internal_stages, producer_group=pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id) ), ) @@ -1808,10 +1814,7 @@ class SSDKernel: cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, ) - cute.arch.barrier( - barrier_id=self.pre_inter_sync_bar_id, - number_of_threads=len(self.pre_inter_warp_id) * 32, - ) + self.pre_inter_sync_barrier.arrive_and_wait() if local_warp_idx == 0: # TMA store P @@ -1824,10 +1827,7 @@ class SSDKernel: tma_p_pipeline.producer_commit() tma_p_pipeline.producer_acquire() - cute.arch.barrier( - barrier_id=self.pre_inter_sync_bar_id, - number_of_threads=len(self.pre_inter_warp_id) * 32, - ) + self.pre_inter_sync_barrier.arrive_and_wait() tma_p_pipeline.producer_tail() # Advance to next tile @@ -2085,7 +2085,7 @@ class SSDKernel: local_tidx, smem_y, tiled_t2r_inter2 ) - tRS_rCompute = cute.make_fragment(tRS_rY.shape, self.acc_dtype) + tRS_rCompute = cute.make_rmem_tensor(tRS_rY.shape, self.acc_dtype) tiled_s2r_x = None tSR_sX = None @@ -2128,7 +2128,7 @@ class SSDKernel: tma_y_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.output_stages, producer_group=pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id) ), ) @@ -2328,10 +2328,7 @@ class SSDKernel: space=cute.arch.SharedSpace.shared_cta, ) # Sync before TMA store - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=len(self.epilog_warp_id) * 32, - ) + self.epilog_sync_barrier.arrive_and_wait() # Async arrive Delta/INTRA2_ACC/INTER2_ACC buffer empty if ( @@ -2366,10 +2363,7 @@ class SSDKernel: # Wait for TMA store tma_y_pipeline.producer_acquire() # Sync before smem store - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, - number_of_threads=len(self.epilog_warp_id) * 32, - ) + self.epilog_sync_barrier.arrive_and_wait() # Advance deltas/intra2_acc/inter2_acc consumer states deltas_consumer_state.advance() @@ -2406,22 +2400,12 @@ class SSDKernel: # Producer tail for TMA store Y tma_y_pipeline.producer_tail() + # Release tensor memory allocation lock + tmem.relinquish_alloc_permit() + # Sync before deallocating tmem + self.tmem_dealloc_sync_barrier.arrive_and_wait() # Dealloc tmem buffer - if warp_idx == self.epilog_warp_id[0]: - cute.arch.barrier( - barrier_id=self.tmem_dealloc_sync_bar_id, - number_of_threads=self.threads_per_cta, - ) - cute.arch.dealloc_tmem( - tmem_ptr_base, - self.num_tmem_cols_total, - is_two_cta=self.use_2cta_instrs, - ) - else: - cute.arch.barrier_arrive( - barrier_id=self.tmem_dealloc_sync_bar_id, - number_of_threads=self.threads_per_cta, - ) + tmem.free(tmem_ptr_base) return @@ -2597,7 +2581,7 @@ class SSDKernel: len([self.mma_intra_warp_id, self.mma_inter_warp_id]), ) x_consumer_group_async = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id) ) return pipeline.PipelineTmaMultiConsumersAsync.create( num_stages=self.input_stages, @@ -2616,7 +2600,7 @@ class SSDKernel: pipeline.Agent.Thread, len([self.mma_intra_warp_id]) ) b_consumer_group_async = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id) ) return pipeline.PipelineTmaMultiConsumersAsync.create( num_stages=self.input_stages, @@ -2651,9 +2635,6 @@ class SSDKernel: len( [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] ), - len( - [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] - ), ) return pipeline.PipelineTmaAsync.create( @@ -2672,9 +2653,7 @@ class SSDKernel: pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) ) d_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - len(self.epilog_warp_id), - len(self.epilog_warp_id), + pipeline.Agent.Thread, len(self.epilog_warp_id) ) return pipeline.PipelineTmaAsync.create( @@ -2690,7 +2669,7 @@ class SSDKernel: pipeline.Agent.Thread, len([self.mma_intra_warp_id]) ) intra1_acc_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id) ) return pipeline.PipelineUmmaAsync.create( num_stages=self.intra1_acc_stages, @@ -2701,7 +2680,7 @@ class SSDKernel: def make_and_init_intra2_q_pipeline(self, intra2_q_full_mbar_ptr): intra2_q_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id) ) intra2_q_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_intra_warp_id]) @@ -2718,7 +2697,7 @@ class SSDKernel: pipeline.Agent.Thread, len([self.mma_intra_warp_id]) ) intra2_acc_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id) ) return pipeline.PipelineUmmaAsync.create( num_stages=self.internal_stages, @@ -2729,7 +2708,7 @@ class SSDKernel: def make_and_init_inter1_b_pipeline(self, inter1_b_full_mbar_ptr): inter1_b_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id) ) inter1_b_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_inter_warp_id]) @@ -2746,7 +2725,7 @@ class SSDKernel: pipeline.Agent.Thread, len([self.mma_inter_warp_id]) ) inter1_acc_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id) ) return pipeline.PipelineUmmaAsync.create( num_stages=self.internal_stages, @@ -2757,7 +2736,7 @@ class SSDKernel: def make_and_init_inter2_p_pipeline(self, inter2_p_full_mbar_ptr): inter2_p_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id) ) inter2_p_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_inter_warp_id]) @@ -2774,7 +2753,7 @@ class SSDKernel: pipeline.Agent.Thread, len([self.mma_inter_warp_id]) ) inter2_acc_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id) ) return pipeline.PipelineUmmaAsync.create( num_stages=self.internal_stages, @@ -3035,7 +3014,7 @@ class SSDKernel: # Partition tmem/register tensor for tensor memory store INTRA2_Q # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...) - tRT_rQ = cute.make_fragment( + tRT_rQ = cute.make_rmem_tensor( cute.slice_(thr_r2t_q.partition_S(tCrQ).shape, (None, None, None, None, 0)), dtype, ) @@ -3049,10 +3028,10 @@ class SSDKernel: self, tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ ): # Make tmp acc type fragments - tCrDeltaA_Row = cute.make_fragment(tQrDeltaA_Row.shape, self.acc_dtype) - tCrDeltaA_Col = cute.make_fragment(tQrDeltaA_Col.shape, self.acc_dtype) - tCrDelta = cute.make_fragment(tQrDelta.shape, self.acc_dtype) - tCompute = cute.make_fragment(tRT_rQ.shape, self.acc_dtype) + tCrDeltaA_Row = cute.make_rmem_tensor(tQrDeltaA_Row.shape, self.acc_dtype) + tCrDeltaA_Col = cute.make_rmem_tensor(tQrDeltaA_Col.shape, self.acc_dtype) + tCrDelta = cute.make_rmem_tensor(tQrDelta.shape, self.acc_dtype) + tCompute = cute.make_rmem_tensor(tRT_rQ.shape, self.acc_dtype) # Combine tTR_rQ/tCrDeltaA_Row/tCrDeltaA_Col/tCrDelta tCrDeltaA_Row.store(tQrDeltaA_Row.load().to(self.acc_dtype)) @@ -3127,7 +3106,7 @@ class SSDKernel: tBsB_s2r = thr_s2r_b.partition_S(smem_bt) # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) - tBrB_s2r = cute.make_fragment( + tBrB_s2r = cute.make_rmem_tensor( cute.slice_(tBsB_s2r.shape, (None, None, None, 0)), dtype, ) @@ -3167,7 +3146,7 @@ class SSDKernel: # Make register fragments for smem load/store of Delta/DeltaA # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) - tBrDelta_s2r = cute.make_fragment(tBsDelta_s2r[smem_tile_coord].shape, dtype) + tBrDelta_s2r = cute.make_rmem_tensor(tBsDelta_s2r[smem_tile_coord].shape, dtype) return s2r_atom_delta, tBsDelta_s2r, tBrDelta_s2r def pre_inter_tmem_load_and_partition_p(self, local_tidx, tInter1, smem_pt): @@ -3195,7 +3174,7 @@ class SSDKernel: tTR_s = thr_t2r.partition_D(smem_tensor) # Make register fragments for tmem load INTER1_ACC # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) - tTR_r = cute.make_fragment( + tTR_r = cute.make_rmem_tensor( tTR_s.shape, dtype, ) @@ -3213,7 +3192,7 @@ class SSDKernel: # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) tRS_sP = thr_r2s_p.partition_D(smem_pt) # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) - tRS_rP = cute.make_fragment( + tRS_rP = cute.make_rmem_tensor( cute.slice_(tRS_sP.shape, (None, None, None, 0)), self.io_dtype ) return tiled_r2s_p, tRS_rP, tRS_sP @@ -3239,10 +3218,10 @@ class SSDKernel: def pre_inter_scale_bt_with_delta( self, tBrB_s2r, tBrDelta_s2r, tBrDeltaA_s2r, last_column ): - tCompute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) - tBrB_Compute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) - tBrDelta_Compute = cute.make_fragment(tBrDelta_s2r.shape, self.acc_dtype) - tBrDeltaA_Compute = cute.make_fragment(tBrDeltaA_s2r.shape, self.acc_dtype) + tCompute = cute.make_rmem_tensor(tBrB_s2r.shape, self.acc_dtype) + tBrB_Compute = cute.make_rmem_tensor(tBrB_s2r.shape, self.acc_dtype) + tBrDelta_Compute = cute.make_rmem_tensor(tBrDelta_s2r.shape, self.acc_dtype) + tBrDeltaA_Compute = cute.make_rmem_tensor(tBrDeltaA_s2r.shape, self.acc_dtype) tBrB_Compute.store(tBrB_s2r.load().to(self.acc_dtype)) tBrDelta_Compute.store(tBrDelta_s2r.load().to(self.acc_dtype)) @@ -3323,7 +3302,7 @@ class SSDKernel: # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES) tSR_sX = thr_s2r_x.partition_S(cute.flat_divide(smem_xt, epi_tile)) # (R2S_ATOM, R2S_M, R2S_N) - tSR_rX = cute.make_fragment( + tSR_rX = cute.make_rmem_tensor( cute.slice_(tSR_sX.shape, (None, None, None, 0, 0, 0)), dtype ) return tiled_s2r_x, tSR_sX, tSR_rX @@ -3360,7 +3339,7 @@ def run( has_d = fuse_scale_d != "none" d_has_hdim = fuse_scale_d == "vector" - print(f"Running B100 Mamba2 SSD with:") + print("Running B100 Mamba2 SSD with:") print(f"GBEHCDLN: {gbehcdln}") print( f"Input/Output dtype: {io_dtype}, Intermediate delta dtype: {cumsum_delta_dtype}, Acc dtype: {acc_dtype}" @@ -3405,7 +3384,7 @@ def run( # Build torch_dtype torch tensor torch_dtype = cutlass_torch.dtype(dtype) - dst_tensor = ref_tensor.to(torch_dtype).cuda() + dst_tensor = ref_tensor.to(dtype=torch_dtype).cuda() cute_tensor = from_dlpack(dst_tensor, assumed_align=16) for mode in dynamic_modes: cute_tensor = cute_tensor.mark_compact_shape_dynamic( diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py index ced5c6f2..ffeffa56 100644 --- a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py @@ -212,7 +212,7 @@ def analyze_relative_diffs(actual, expected): ) # Print max relative difference info - print(f"Maximum relative difference:") + print("Maximum relative difference:") print(f"Position: {max_rel_diff_pos}") print(f"Value: {max_rel_diff:.6e}") print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}") @@ -236,7 +236,7 @@ def analyze_relative_diffs(actual, expected): print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)") else: print( - f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)" + f"Elements with {rtol_levels[i - 1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)" ) # Print elements exceeding the largest rtol diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py index 544b4772..6d40479e 100644 --- a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py @@ -29,7 +29,6 @@ from typing import Tuple from cutlass.cutlass_dsl import ( - Boolean, Integer, Int32, min, @@ -121,8 +120,8 @@ class Mamba2SSDTileScheduler: ) # called by host - @dsl_user_op @staticmethod + @dsl_user_op def create( params: Mamba2SSDTileSchedulerParams, block_idx: Tuple[Integer, Integer, Integer], diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py b/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py new file mode 100644 index 00000000..7d8a4b07 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py @@ -0,0 +1,3113 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from enum import Enum, auto +from math import log2, ceil +from typing import Optional, Union + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.cute.testing as testing +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.runtime import from_dlpack + +""" +A mixed-input GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL. + +This example demonstrates an implementation of mixed-input GEMM using a TMA plus Blackwell SM100 TensorCore +warp-specialized persistent kernel. + +The inputs A and B have different data types. In this example, it's assumed that A is the narrow-precision tensor +and B holds data with a wider precision. +MMA will work in the wide precision of tensor B and tensor A will be transformed to the wide precision of tensor B +following 1 of the 2 possible modes as follows: + +1. convert-only mode: + C = type_convert(A) x B + +In convert-only mode, tensor A is directly converted to the wide precision of tensor B. + +2. convert-scale mode: + C = (type_convert(A) * scale) x B + +In convert-scale mode, tensor A is first converted to the wide precision of tensor B and then scaled by the scale tensor. +The scale tensor is in the same precision as tensor B. +The mode is determined by tensor A's data type as follows: +- if tensor A is in int8 or uint8, convert-only mode is used. +- if tensor A is in int4, convert-scale mode is used. + +The output tensor C could have the same precision as tensor B or fp32. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/mixed_input_gemm.py \ + --a_dtype Int8 --b_dtype BFloat16 \ + --scale_granularity_m 0 --scale_granularity_k 0 \ + --c_dtype BFloat16 --acc_dtype Float32 \ + --mma_tiler_mnk 128,128,64 --cluster_shape_mn 1,1 \ + --mnkl 256,512,8192,1 + +Input A and B have int8 and bf16 data types, respectively. The Blackwell tcgen05 MMA tile shape +is specified as (128,128,64) and the cluster shape is (1,1). The MMA accumulator and output data type +are set as fp32 and bf16, respectively. As tensor A is int8, convert-only mode is used. +scale_granularity_m and scale_granularity_k are set as 0 for convert-only mode. + +Here is an example of running convert-scale mode: + +.. code-block:: bash + + python examples/blackwell/mixed_input_gemm.py \ + --a_dtype Int4 --b_dtype BFloat16 \ + --scale_granularity_m 1 --scale_granularity_k 256 \ + --c_dtype BFloat16 --acc_dtype Float32 \ + --mma_tiler_mnk 256,128,128 --cluster_shape_mn 2,1 \ + --use_2cta_instrs --use_tma_store \ + --mnkl 1024,8192,6144,16 + +Input A and B have int4 and bf16 data types, respectively. The scale granularity is set as (1,256), +which means each element along the m mode of tensor A has its own scale element and 256 contiguous elements +along the k mode share the same scale element. There is no scale reuse along the L mode. If the GEMM shape is +(M, N, K, L), then the scale tensor shape is (M // scale_granularity_m, K // scale_granularity_k, L), +which is (1024, 6144/256, 16) in this example. +The Blackwell tcgen05 MMA tile shape is specified as (256,128,128) and tcgen05 2CTA feature is enabled. +The cluster shape is (2,1). The MMA accumulator and output data type are set as fp32 and bf16, respectively. +As tensor A is int4, the convert-scale mode is used. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/mixed_input_gemm.py \ + --a_dtype Int8 --b_dtype BFloat16 \ + --scale_granularity_m 0 --scale_granularity_k 0 \ + --c_dtype BFloat16 --acc_dtype Float32 \ + --mma_tiler_mnk 128,128,64 --cluster_shape_mn 1,1 \ + --mnkl 256,512,8192,1 \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + +Besides the requirements from the Blackwell dense GEMM example, there are some constraints for this example: +* The narrow-precision is constrained to be int8, uint8, or int4 and the other data type is bf16 or f16. +* Output data types could only be fp16, bf16, or fp32. +* The scale_granularity_m must be 1 currently. +* The scale_granularity_k must be a multiple of mma_tiler_k and also be divisible by gemm_k. +* The scale tensor must be in m-major mode. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class TransformMode(Enum): + """ + An enumeration for the possible transform modes of a mixed-input GEMM. + """ + + ConvertOnly = auto() + ConvertScale = auto() + + +class MixedInputGemmKernel: + """ + Mixed-input GEMM kernel for NVIDIA Blackwell SM100 architecture. + + This kernel supports GEMM operations where input tensors A and B have different + data types, with tensor A being transformed to the precision of tensor B before + matrix multiplication. + + :param scale_granularity_m: Number of elements sharing the same scale factor along the M mode + :type scale_granularity_m: int + :param scale_granularity_k: Number of elements sharing the same scale factor along the K mode + :type scale_granularity_k: int + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mnk: Shape of the Matrix Multiply-Accumulate (MMA) tile (M, N, K) + :type mma_tiler_mnk: tuple[int, int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + """ + + def __init__( + self, + scale_granularity_m: int, + scale_granularity_k: int, + acc_dtype: type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + use_tma_store: bool, + ): + """ + Initializes the mixed-input GEMM kernel with a specified configuration. + """ + # Scale granularity defines how many elements share the same scale factor + # along the M and K modes. + self.scale_granularity_m = scale_granularity_m + self.scale_granularity_k = scale_granularity_k + # Set transform mode + if cutlass.const_expr( + self.scale_granularity_m == 0 and self.scale_granularity_k == 0 + ): + self.scale_mode = TransformMode.ConvertOnly + else: + self.scale_mode = TransformMode.ConvertScale + self.acc_dtype = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + self.mma_tiler = mma_tiler_mnk + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.scale_tma_warp_id = 6 + # 4 warps to do the transformation + self.transform_warp_id = ( + 8, + 9, + 10, + 11, + ) + self.threads_per_cta = 32 * ( + max( + ( + self.mma_warp_id, + self.tma_warp_id, + self.scale_tma_warp_id, + *self.epilog_warp_id, + *self.transform_warp_id, + ) + ) + + 1 + ) + + # Set barrier id for cta sync, epilogue sync, tmem ptr sync, and transform sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + 1, 32 * len(self.epilog_warp_id) + ) + self.tmem_ptr_sync_barrier = pipeline.NamedBarrier(2, self.threads_per_cta) + self.transform_sync_barrier = pipeline.NamedBarrier( + 3, 32 * len(self.transform_warp_id) + ) + self.cta_sync_barrier = pipeline.NamedBarrier(4, self.threads_per_cta) + + self.smem_buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Deduce where the transformed A tensor is stored + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/scale/B/C stage counts in shared memory + - Setting up transformed A stage count in shared memory or tensor memory + - Computing A/transformed A/scale/B/C memory layout + - Computing tensor memory allocation columns + """ + # Deduce where the transformed A tensor is stored, shared memory(SMEM) or tensor memory(TMEM) + self.transform_a_source = self._get_transform_a_source(self.a_major_mode) + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.mma_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + self.transform_a_source, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Compute tensor memory(TMEM) columns and stages for each pipeline + ( + self.num_load2trans_stage, + self.num_scale_load2trans_stage, + self.num_trans2mma_stage, + self.num_acc_stage, + self.num_c_stage, + self.num_acc_tmem_cols, + self.num_a_tmem_cols, + ) = self._compute_stages_and_tmem_cols( + tiled_mma, + self.mma_tiler, + self.cta_tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.b_dtype, + self.c_dtype, + self.c_layout, + self.transform_a_source, + self.scale_granularity_m, + self.scale_granularity_k, + self.smem_buffer_align_bytes, + self.use_tma_store, + self.scale_mode, + ) + # Ensure load2trans and trans2mma pipelines share same stage count, + # so we can use same pipeline stage index to slice both A and B buffers + if cutlass.const_expr(self.num_load2trans_stage != self.num_trans2mma_stage): + self.num_load2trans_stage = min( + self.num_load2trans_stage, self.num_trans2mma_stage + ) + self.num_trans2mma_stage = self.num_load2trans_stage + + # Align TMEM columns for allocation + # TMEM allocation requires power-of-2 column alignment + # and must meet minimum allocation requirements + self.num_tmem_alloc_cols = MixedInputGemmKernel.align_up( + self.num_acc_tmem_cols + self.num_a_tmem_cols, + cute.arch.SM100_TMEM_MIN_ALLOC_COLUMNS, + ) + self.num_tmem_alloc_cols = 2 ** (ceil(log2(self.num_tmem_alloc_cols))) + # Get smem layout for C tensor when TMA store is enabled + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if self.use_tma_store + else None + ) + # Get smem layout for A, transformed A, and B + ( + self.smem_layout_a, + self.smem_layout_a_transform, + self.smem_layout_b, + ) = self._compute_smem_layout( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.num_load2trans_stage, + self.num_trans2mma_stage, + ) + # Get smem layout for scale tensor + self.smem_layout_scale_per_stage = None + self.smem_layout_scale = None + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + # Get smem layout for scale tensor + ( + self.smem_layout_scale_per_stage, + self.smem_layout_scale, + ) = self.get_smem_layout_scale() + + def _validate_inputs( + self, + a: cute.Tensor, + a_scale: Optional[cute.Tensor], + b: cute.Tensor, + c: cute.Tensor, + ) -> None: + """ + Validates input tensors and their properties. + + :param a: Input tensor A. + :type a: cute.Tensor + :param a_scale: Scale tensor for tensor A (None for ConvertOnly mode). + :type a_scale: Optional[cute.Tensor] + :param b: Input tensor B. + :type b: cute.Tensor + :param c: Output tensor C. + :type c: cute.Tensor + :raises ValueError: If inputs don't meet kernel requirements. + """ + # Validate scale tensor major mode + if cutlass.const_expr( + self.scale_mode == TransformMode.ConvertScale + and utils.LayoutEnum.from_tensor(a_scale).mma_major_mode() + != tcgen05.OperandMajorMode.MN + ): + raise ValueError("scale_major_mode should be m-major") + + @cute.jit + def __call__( + self, + a: cute.Tensor, + a_scale: Optional[cute.Tensor], # None for ConvertOnly mode + b: cute.Tensor, + c: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + """ + Executes the Mixed Input GEMM operation. + + This method sets up the kernel parameters, computes the grid size, + defines the shared storage, and launches the kernel. + + The execution steps are as follows: + - Setup static attributes before smem/grid/tma computation. + - Setup TMA load/store atoms and tensors. + - Compute grid size with regard to hardware constraints. + - Define shared storage for kernel. + - Launch the kernel synchronously. + + :param a: Input tensor A. + :type a: cute.Tensor + :param a_scale: Scale tensor for tensor A (None for ConvertOnly mode). + :type a_scale: Optional[cute.Tensor] + :param b: Input tensor B. + :type b: cute.Tensor + :param c: Output tensor C. + :type c: cute.Tensor + :param max_active_clusters: Maximum number of active clusters to launch. + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream to launch the kernel on. + :type stream: cuda.CUstream + """ + self.a_dtype: type[cutlass.Numeric] = a.element_type + self.a_scale_dtype: type[cutlass.Numeric] = ( + a_scale.element_type + if self.scale_mode is TransformMode.ConvertScale + else None + ) + self.b_dtype: type[cutlass.Numeric] = b.element_type + self.c_dtype: type[cutlass.Numeric] = c.element_type + self.mma_dtype = self.b_dtype + + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.scale_major_mode = ( + utils.LayoutEnum.from_tensor(a_scale).mma_major_mode() + if self.scale_mode is TransformMode.ConvertScale + else None + ) + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + # Get gmem layout for scale tensor + self.gmem_layout_scale = self.get_gmem_layout_scale(a.shape) + + # Validate inputs + self._validate_inputs(a, a_scale, b, c) + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.mma_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + self.transform_a_source, + ) + # Set up gmem copy atoms for A, scale, and B + a_op = self._get_tma_atom_kind(self.is_a_mcast, self.use_2cta_instrs, False) + b_op = self._get_tma_atom_kind(self.is_b_mcast, self.use_2cta_instrs, True) + a_scale_op = a_op + # Deduce TMA copy atom and TMA tensor for A, scale, and B + smem_layout_a_per_stage = cute.slice_(self.smem_layout_a, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + smem_layout_a_per_stage, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + tma_atom_scale, tma_tensor_scale = None, None + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + # Partition smem layout for scale tensor to make it compatible with TMA atom + smem_layout_for_tma_atom = cute.get( + tiled_mma._thrfrg_A(self.smem_layout_scale_per_stage.outer), mode=[1] + ) + # ((MMA_M, MMA_K), REST_M, REST_K) + smem_layout_for_tma_atom = cute.dice( + smem_layout_for_tma_atom, + (1, (1,) * cute.rank(self.smem_layout_scale_per_stage.outer)), + ) + tma_atom_scale, tma_tensor_scale = cute.nvgpu.make_tiled_tma_atom_A( + a_scale_op, + cute.make_tensor(a_scale.iterator, self.gmem_layout_scale), + smem_layout_for_tma_atom, + # (SCALE_M, 1, SCALE_K) + (self.scale_tile_shape[0], 1, self.scale_tile_shape[1]), + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 + if a_scale.element_type is cutlass.Float32 + else None + ), + ) + + smem_layout_b_per_stage = cute.slice_(self.smem_layout_b, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + smem_layout_b_per_stage, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + # Calculate copy size for tensor A, B, and scale + a_copy_size = cute.size_in_bytes(self.a_dtype, smem_layout_a_per_stage) + b_copy_size = cute.size_in_bytes(self.b_dtype, smem_layout_b_per_stage) + a_scale_copy_size = ( + cute.size_in_bytes(self.a_scale_dtype, self.smem_layout_scale_per_stage) + if self.scale_mode is TransformMode.ConvertScale + else 0 + ) + + self.num_tma_load_bytes_a = a_copy_size + self.num_tma_load_bytes_b = b_copy_size * cute.size(tiled_mma.thr_id.shape) + self.num_tma_load_bytes_scale = a_scale_copy_size + self.tile_sched_params, grid = self._compute_grid( + c, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + max_active_clusters, + ) + + tma_atom_c = None + tma_tensor_c = None + c_smem_size = 0 + if cutlass.const_expr(self.use_tma_store): + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + self.epi_tile, + ) + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Shared memory structure + a_smem_size = cute.cosize(self.smem_layout_a.outer) + b_smem_size = cute.cosize(self.smem_layout_b.outer) + a_transform_smem_size = ( + cute.cosize(self.smem_layout_a_transform.outer) + if self.transform_a_source == tcgen05.OperandSource.SMEM + else 0 + ) + a_scale_smem_size = ( + cute.cosize(self.smem_layout_scale.outer) + if self.scale_mode is TransformMode.ConvertScale + else 0 + ) + + @cute.struct + class SharedStorage: + a_load2trans_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_load2trans_stage + ] + a_load2trans_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_load2trans_stage + ] + a_scale_load2trans_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_scale_load2trans_stage + ] + a_scale_load2trans_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_scale_load2trans_stage + ] + a_trans2mma_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_trans2mma_stage + ] + a_trans2mma_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_trans2mma_stage + ] + b_load2mma_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_load2trans_stage + ] + b_load2mma_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_load2trans_stage + ] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # Tensor buffers + # (EPI_TILE_M, EPI_TILE_N, STAGE) + smem_C: cute.struct.Align[ + cute.struct.MemRange[self.c_dtype, c_smem_size], + self.smem_buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + smem_A: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, a_smem_size], + self.smem_buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + smem_B: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, b_smem_size], + self.smem_buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + smem_A_transform: cute.struct.Align[ + cute.struct.MemRange[self.mma_dtype, a_transform_smem_size], + self.smem_buffer_align_bytes, + ] + # (MMA, MMA_M_SCALE, MMA_K_SCALE, STAGE) + smem_A_scale: cute.struct.Align[ + cute.struct.MemRange[self.mma_dtype, a_scale_smem_size], + self.smem_buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch kernel + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_scale, + tma_tensor_scale, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c, + self.cluster_layout_vmnk, + self.smem_layout_a, + self.smem_layout_scale, + self.smem_layout_a_transform, + self.smem_layout_b, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_s: Optional[cute.CopyAtom], + mS_mkl: Optional[cute.Tensor], + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout: cute.ComposedLayout, + scale_smem_layout: cute.ComposedLayout, + a_smem_layout_transform: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, + c_smem_layout_staged: cute.ComposedLayout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + ): + """ + GPU device kernel performing the Persistent Mixed-Input GEMM computation. + """ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + # Prefetch TMA descriptors + if warp_idx == self.epilog_warp_id[0]: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + cpasync.prefetch_descriptor(tma_atom_s) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + bidx, bidy, bidz = cute.arch.block_idx() + # Compute how many k_tiles share the same scale + num_k_tiles_per_scale = self.scale_granularity_k // self.cta_tile_shape_mnk[2] + + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + tidx, _, _ = cute.arch.thread_idx() + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize load2transform pipeline, which tracks the dependencies between TMA's loading + # of A and B, and the transformation of A and MMA's consumption + transform_thread_idx = ( + tidx - 32 * self.transform_warp_id[0] + if tidx >= 32 * self.transform_warp_id[0] + else tidx + ) + a_load2trans_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.a_load2trans_full_mbar_ptr.data_ptr(), + num_stages=self.num_load2trans_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_mcast_ctas_a * len(self.transform_warp_id), + ), + tx_count=self.num_tma_load_bytes_a, + cta_layout_vmnk=cluster_layout_vmnk, + tidx=transform_thread_idx, + mcast_mode_mn=(1, 0), # multicast for A will only happen on the M-mode + ) + # Initialize scale_load2trans pipeline, which tracks the dependencies between TMA's loading + # of scale, and the transformation of A + scale_load2trans_pipeline = None + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + num_producers_a_scale = self.num_mcast_ctas_a + scale_load2trans_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.a_scale_load2trans_full_mbar_ptr.data_ptr(), + num_stages=self.num_scale_load2trans_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + num_producers_a_scale + * len(self.transform_warp_id) + * num_k_tiles_per_scale, + ), + tx_count=self.num_tma_load_bytes_scale, + cta_layout_vmnk=cluster_layout_vmnk, + tidx=transform_thread_idx, + mcast_mode_mn=( + 1, + 0, + ), # multicast for scale_a will only happen on the M-mode + ) + # Initialize transform2mma pipeline, which tracks the dependencies between the transformation + # of A and MMA's consumption of transformed A + cta_v_size = cute.size(cluster_layout_vmnk, mode=[0]) + trans2mma_pipeline = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.a_trans2mma_full_mbar_ptr.data_ptr(), + num_stages=self.num_trans2mma_stage, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.transform_warp_id) * cta_v_size, + ), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + cta_layout_vmnk=cluster_layout_vmnk, + ) + # Initialize pipeline for tensor B load to MMA + # MMA warp informs TMA warp to proceed to load next tile of B tensor + b_load2mma_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.b_load2mma_full_mbar_ptr.data_ptr(), + num_stages=self.num_load2trans_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_mcast_ctas_b + ), + tx_count=self.num_tma_load_bytes_b, + cta_layout_vmnk=cluster_layout_vmnk, + mcast_mode_mn=(0, 1), # multicast for B will only happen on the N-mode + ) + # Initialize accumulator pipeline, which tracks the dependencies between + # MMA's computation of accumulators and epilogue warps' consumption of accumulators + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, cta_v_size * len(self.epilog_warp_id) + ), + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_ptr_sync_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cutlass.const_expr(cute.size(self.cluster_shape_mn) > 1): + cute.arch.cluster_arrive_relaxed() + + # Setup smem tensor A/scale/B/C + sC = ( + storage.smem_C.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if self.use_tma_store + else None + ) + sA_input = storage.smem_A.get_tensor( + a_smem_layout.outer, swizzle=a_smem_layout.inner + ) + sS_input = ( + storage.smem_A_scale.get_tensor( + scale_smem_layout.outer, swizzle=scale_smem_layout.inner + ) + if self.scale_mode is TransformMode.ConvertScale + else None + ) + sB_input = storage.smem_B.get_tensor( + b_smem_layout.outer, swizzle=b_smem_layout.inner + ) + sA_transform = None + # Get smem tensor for transformed A when transform_a_source is SMEM + if cutlass.const_expr(self.transform_a_source == tcgen05.OperandSource.SMEM): + sA_transform = storage.smem_A_transform.get_tensor( + a_smem_layout_transform.outer, swizzle=a_smem_layout_transform.inner + ) + + # Compute multicast mask for A/B buffer full + a_full_mcast_mask = None + b_full_mcast_mask = None + s_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + # scale tensor share the same multicast mask with A tensor + s_full_mcast_mask = a_full_mcast_mask + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # local_tile partition global tensors + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gS_mkl = ( + cute.local_tile( + mS_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + if self.scale_mode is TransformMode.ConvertScale + else None + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # Partition global tensor for TiledMMA_A/B/C + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgS = ( + thr_mma.partition_A(gS_mkl) + if self.scale_mode is TransformMode.ConvertScale + else None + ) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # Setup copy atom to load A from shared memory for further transformation + copy_atom_a_input = ( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.a_dtype, num_bits_per_copy=32 + ) + if self.scale_mode is TransformMode.ConvertScale + else None + ) + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(self.mma_tiler, (1, None, 1)) + ) + # Setup copy atom to store transformed A into tensor memory or shared memory + copy_atom_a_transform = self._get_copy_atom_a_transform( + self.mma_dtype, + self.use_2cta_instrs, + self.transform_a_source, + a_smem_shape, + self.a_dtype, + ) + + # Partition global/shared tensor for TMA load A/B + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA_input, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + + tCsS = None + tSsS = None + tSgS = None + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + # (MMA, MMA_M, MMA_K, STAGE) + tCsS = thr_mma.partition_A(sS_input) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tSsS, tSgS = self.scale_tma_partition( + tCsS, + tCgS, + tma_atom_s, + block_in_cluster_coord_vmnk, + a_cta_layout, + ) + + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB_input, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB_input) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # Cluster wait before TMEM alloc and ensure pipelines are ready + if cutlass.const_expr(cute.size(self.cluster_shape_mn) > 1): + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + # TMEM allocation + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + # Get the pointer to the TMEM buffer + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + accumulators = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + tCrA = None + if cutlass.const_expr(self.transform_a_source == tcgen05.OperandSource.TMEM): + tmem_ptr_transform = cute.recast_ptr( + accumulators.iterator + self.num_acc_tmem_cols, dtype=self.mma_dtype + ) + tCrA = cute.make_tensor( + tmem_ptr_transform, + tiled_mma.make_fragment_A(a_smem_layout_transform.outer).layout, + ) + else: + tCrA = tiled_mma.make_fragment_A(sA_transform) + + # Specialized TMA load warp for A/B tensor + if warp_idx == self.tma_warp_id: + # Persistent tile scheduling loop + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + a_load2trans_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_load2trans_stage + ) + b_load2mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_load2trans_stage + ) + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + a_load2trans_producer_state.reset_count() + peek_load2trans_empty_status = cutlass.Boolean(1) + if a_load2trans_producer_state.count < k_tile_cnt: + peek_load2trans_empty_status = ( + a_load2trans_pipeline.producer_try_acquire( + a_load2trans_producer_state + ) + ) + b_load2mma_producer_state.reset_count() + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + a_load2trans_pipeline.producer_acquire( + a_load2trans_producer_state, peek_load2trans_empty_status + ) + b_load2mma_pipeline.producer_acquire(b_load2mma_producer_state) + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_slice[(None, a_load2trans_producer_state.count)], + tAsA[(None, a_load2trans_producer_state.index)], + tma_bar_ptr=a_load2trans_pipeline.producer_get_barrier( + a_load2trans_producer_state + ), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, b_load2mma_producer_state.count)], + tBsB[(None, b_load2mma_producer_state.index)], + tma_bar_ptr=b_load2mma_pipeline.producer_get_barrier( + b_load2mma_producer_state + ), + mcast_mask=b_full_mcast_mask, + ) + a_load2trans_pipeline.producer_commit(a_load2trans_producer_state) + b_load2mma_pipeline.producer_commit(b_load2mma_producer_state) + a_load2trans_producer_state.advance() + b_load2mma_producer_state.advance() + if a_load2trans_producer_state.count < k_tile_cnt: + peek_load2trans_empty_status = ( + a_load2trans_pipeline.producer_try_acquire( + a_load2trans_producer_state + ) + ) + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # Wait A/B buffer empty + a_load2trans_pipeline.producer_tail(a_load2trans_producer_state) + b_load2mma_pipeline.producer_tail(b_load2mma_producer_state) + + # Specialized TMA load for scale tensor + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + if warp_idx == self.scale_tma_warp_id: + # Persistent tile scheduling loop + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + scale_load2trans_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_load2trans_stage + ) + scale_k_tile_cnt = cute.size(mS_mkl.layout.shape[1][1]) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + # ((atom_v, rest_v), RestK) + tSgS_slice = tSgS[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # Filter zeros in rest mode + rest_filtered = cute.filter_zeros(tSgS_slice[(0, None)].layout) + tSgS_slice_filtered = cute.make_tensor( + tSgS_slice.iterator, + cute.make_layout( + (tSgS_slice.layout[0].shape, rest_filtered.shape), + stride=(tSgS_slice.layout[0].stride, rest_filtered.stride), + ), + ) + + scale_load2trans_producer_state.reset_count() + peek_scale_load2trans_empty_status = cutlass.Boolean(1) + if scale_load2trans_producer_state.count < scale_k_tile_cnt: + peek_scale_load2trans_empty_status = ( + scale_load2trans_pipeline.producer_try_acquire( + scale_load2trans_producer_state + ) + ) + for k_tile in cutlass.range(0, scale_k_tile_cnt, 1, unroll=1): + scale_load2trans_pipeline.producer_acquire( + scale_load2trans_producer_state, + peek_scale_load2trans_empty_status, + ) + # TMA load scale + cute.copy( + tma_atom_s, + tSgS_slice_filtered[ + (None, scale_load2trans_producer_state.count) + ], + tSsS[(None, scale_load2trans_producer_state.index)], + tma_bar_ptr=scale_load2trans_pipeline.producer_get_barrier( + scale_load2trans_producer_state + ), + mcast_mask=s_full_mcast_mask, + ) + + scale_load2trans_producer_state.advance() + peek_scale_load2trans_empty_status = cutlass.Boolean(1) + if scale_load2trans_producer_state.count < scale_k_tile_cnt: + peek_scale_load2trans_empty_status = ( + scale_load2trans_pipeline.producer_try_acquire( + scale_load2trans_producer_state + ) + ) + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # Wait scale buffer empty + scale_load2trans_pipeline.producer_tail(scale_load2trans_producer_state) + + # Specialized transform warps + if warp_idx >= self.transform_warp_id[0]: + transform_local_tidx = tidx - 32 * self.transform_warp_id[0] + # Partition tensors for transform input and output and set up the copy atom + # used for loading and storing transformed A tensor + ( + src_copy_a, + dst_copy_a, + tAsA_input, + tAsA_transform, + ) = self.transform_partition( + self.transform_a_source, + self.scale_mode, + copy_atom_a_input, + copy_atom_a_transform, + sA_input, + ( + tCrA + if self.transform_a_source == tcgen05.OperandSource.TMEM + else sA_transform + ), + transform_local_tidx, + ) + # make fragment for input A and transformed A + tArA = cute.make_rmem_tensor( + cute.select(tAsA_input.layout, mode=[0, 1, 2, 3]).shape, + dtype=tAsA_input.element_type, + ) + tArA_transform = cute.make_rmem_tensor( + cute.select(tAsA_input.layout, mode=[0, 1, 2, 3]).shape, + dtype=self.mma_dtype, + ) + # Partition scale tensor + smem_thr_copy_S = None + tSsS_trans = None + tSrS_copy = None + tSrS = None + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + smem_thr_copy_S, tSsS_trans, tSrS_copy, tSrS = self.scale_partition( + src_copy_a, tCsS, transform_local_tidx, self.mma_dtype + ) + assert cute.size(tSrS, mode=[0]) == cute.size(tArA, mode=[0]), ( + "tSrS and tArA have different leading dimension" + ) + assert cute.size(tSrS) == cute.size(tArA), ( + "tSrS and tArA have different shape" + ) + # Make all modes except mode0 into loops + tArA_load = cute.group_modes(tArA, 1, cute.rank(tArA)) + tSrS_load = ( + cute.group_modes(tSrS, 1, cute.rank(tSrS)) + if self.scale_mode is TransformMode.ConvertScale + else None + ) + tArA_transform_store = cute.group_modes( + tArA_transform, 1, cute.rank(tArA_transform) + ) + + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + a_load2trans_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, + self.num_load2trans_stage, + ) + scale_load2trans_consumer_state = ( + pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, + self.num_scale_load2trans_stage, + ) + if self.scale_mode is TransformMode.ConvertScale + else None + ) + trans2mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, + self.num_trans2mma_stage, + ) + while work_tile.is_valid_tile: + a_load2trans_consumer_state.reset_count() + peek_load2trans_full_status = cutlass.Boolean(1) + if a_load2trans_consumer_state.count < k_tile_cnt: + peek_load2trans_full_status = ( + a_load2trans_pipeline.consumer_try_wait( + a_load2trans_consumer_state + ) + ) + if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale): + scale_load2trans_consumer_state.reset_count() + trans2mma_producer_state.reset_count() + peek_trans2mma_empty_status = cutlass.Boolean(1) + if trans2mma_producer_state.count < k_tile_cnt: + peek_trans2mma_empty_status = ( + trans2mma_pipeline.producer_try_acquire( + trans2mma_producer_state + ) + ) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + a_load2trans_pipeline.consumer_wait( + a_load2trans_consumer_state, peek_load2trans_full_status + ) + # Load A from shared memory + cute.autovec_copy( + tAsA_input[ + (None, None, None, None, a_load2trans_consumer_state.index) + ], + tArA, + ) + if cutlass.const_expr( + self.scale_mode == TransformMode.ConvertScale + ): + scale_load2trans_pipeline.consumer_wait( + scale_load2trans_consumer_state + ) + trans2mma_pipeline.producer_acquire( + trans2mma_producer_state, peek_trans2mma_empty_status + ) + # load scale tensor when needed + if cutlass.const_expr( + self.scale_mode == TransformMode.ConvertScale + ): + if k_tile % num_k_tiles_per_scale == 0: + tSsS_slice = tSsS_trans[ + ( + None, + None, + None, + None, + scale_load2trans_consumer_state.index, + ) + ] + tSsS_slice_filtered = cute.make_tensor( + tSsS_slice.iterator, + cute.filter_zeros(tSsS_slice.layout), + ) + cute.autovec_copy(tSsS_slice_filtered, tSrS_copy) + + for idx in cutlass.range_constexpr(cute.size(tArA_load, mode=[1])): + # Load tensor A and convert it to mma dtype + tensor_transformed = ( + tArA_load[(None, idx)].load().to(self.mma_dtype) + ) + if cutlass.const_expr( + self.scale_mode == TransformMode.ConvertScale + ): + scale = cute.TensorSSA( + tSrS_load[(None, idx)].load(), + tensor_transformed.shape, + self.mma_dtype, + ) + # Apply scale + tensor_transformed = tensor_transformed * scale + tArA_transform_store[(None, idx)].store(tensor_transformed) + # Store transformed A to tensor memory or shared memory + if cutlass.const_expr(dst_copy_a is not None): + cute.copy( + dst_copy_a, + tArA_transform, + tAsA_transform[ + (None, None, None, None, trans2mma_producer_state.index) + ], + ) + else: + cute.autovec_copy( + tArA_transform, + tAsA_transform[ + (None, None, None, None, trans2mma_producer_state.index) + ], + ) + # Ensure all transform threads have finished the copy and reached the fence + self.transform_sync_barrier.arrive_and_wait() + if cutlass.const_expr( + self.transform_a_source == tcgen05.OperandSource.TMEM + ): + cute.arch.fence_view_async_tmem_store() + else: + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # Signal the completion of transformation + trans2mma_pipeline.producer_commit(trans2mma_producer_state) + # Signal the completion of using A and scale tensors + a_load2trans_pipeline.consumer_release(a_load2trans_consumer_state) + if cutlass.const_expr( + self.scale_mode == TransformMode.ConvertScale + ): + scale_load2trans_pipeline.consumer_release( + scale_load2trans_consumer_state + ) + if (k_tile + 1) % num_k_tiles_per_scale == 0: + scale_load2trans_consumer_state.advance() + + a_load2trans_consumer_state.advance() + trans2mma_producer_state.advance() + if a_load2trans_consumer_state.count < k_tile_cnt: + peek_load2trans_full_status = ( + a_load2trans_pipeline.consumer_try_wait( + a_load2trans_consumer_state + ) + ) + if trans2mma_producer_state.count < k_tile_cnt: + peek_trans2mma_empty_status = ( + trans2mma_pipeline.producer_try_acquire( + trans2mma_producer_state + ) + ) + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # Wait a_transform buffer empty + trans2mma_pipeline.producer_tail(trans2mma_producer_state) + + # Specialized MMA warp + if warp_idx == self.mma_warp_id: + tCtAcc_base = accumulators + # Persistent tile scheduling loop + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + trans2mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_trans2mma_stage + ) + b_load2mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_load2trans_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + b_load2mma_consumer_state.reset_count() + trans2mma_consumer_state.reset_count() + peek_trans2mma_full_status = cutlass.Boolean(1) + if is_leader_cta: + if trans2mma_consumer_state.count < k_tile_cnt: + peek_trans2mma_full_status = ( + trans2mma_pipeline.consumer_try_wait( + trans2mma_consumer_state + ) + ) + acc_pipeline.producer_acquire(acc_producer_state) + + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Mma mainloop + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + trans2mma_pipeline.consumer_wait( + trans2mma_consumer_state, peek_trans2mma_full_status + ) + b_load2mma_pipeline.consumer_wait(b_load2mma_consumer_state) + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + trans2mma_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + trans2mma_pipeline.consumer_release(trans2mma_consumer_state) + b_load2mma_pipeline.consumer_release(b_load2mma_consumer_state) + trans2mma_consumer_state.advance() + b_load2mma_consumer_state.advance() + peek_trans2mma_full_status = cutlass.Boolean(1) + if trans2mma_consumer_state.count < k_tile_cnt: + peek_trans2mma_full_status = ( + trans2mma_pipeline.consumer_try_wait( + trans2mma_consumer_state + ) + ) + # Async arrive accumulator buffer full + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # Wait for accumulator buffer empty + acc_pipeline.producer_tail(acc_producer_state) + + # Specialized epilogue warps + if warp_idx < self.mma_warp_id: + epi_tidx = tidx + tCtAcc_base = accumulators + # Partition for epilogue + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_gC_partitioned = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + ( + simt_atom, + tTR_rC, + tTR_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + # Persistent tile scheduling loop + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + else: + tTR_gC = tTR_gC_partitioned[ + (None, None, None, None, None, *mma_tile_coord_mnl) + ] + + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + # Store accumulator to global memory in subtiles + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # Load accumulator from tensor memory buffer to register + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + if cutlass.const_expr(self.use_tma_store): + # Convert to C type + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = acc_vec.to(self.c_dtype) + tRS_rC.store(acc_vec) + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + # Store C to shared memory + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + # TMA store C to global memory + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + else: + # Convert to C type + acc_vec = tTR_rAcc.load() + acc_vec = acc_vec.to(self.c_dtype) + tTR_rC.store(acc_vec) + # Store C to global memory + cute.autovec_copy( + tTR_rC, tTR_gC[(None, None, None, subtile_idx)] + ) + # Async arrive accumulator buffer empty + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Dealloc the tensor memory buffer + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + def scale_tma_partition( + self, + tCsS: cute.Tensor, + tCgS: cute.Tensor, + tma_atom_s: cute.CopyAtom, + block_in_cluster_coord_vmnk: cute.Coord, + scale_cta_layout: cute.Layout, + ) -> tuple[cute.Tensor, cute.Tensor]: + """ + Perform TMA partition for scale tensor. + + This method partitions the gobal memory and shared memory buffer for scale tensor for TMA load. + + :param tCsS: Input scale shared memory tensor + :type tCsS: cute.Tensor + :param tCgS: Input scale global memory tensor + :type tCgS: cute.Tensor + :param tma_atom_s: TMA copy atom for scale tensor + :type tma_atom_s: cute.CopyAtom + :param block_in_cluster_coord_vmnk: CTA coord in the cluster + :type block_in_cluster_coord_vmnk: cute.Coord + :param scale_cta_layout: Layout of CTA from the view of the scale tensor + :type scale_cta_layout: cute.Layout + + :return: A tuple containing (tSsS, tSgS) where: + - tSsS: Partitioned scale tensor in shared memory + - tSgS: Partitioned scale tensor in global memory + :rtype: tuple[cute.Tensor, cute.Tensor] + """ + tSsS, tSgS = cpasync.tma_partition( + tma_atom_s, + block_in_cluster_coord_vmnk[2], + scale_cta_layout, + cute.group_modes(tCsS, 0, 3), + cute.group_modes(tCgS, 0, 3), + ) + # Add rest_v mode + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tSsS = cute.make_tensor( + tSsS.iterator, + cute.make_layout( + ((tSsS.layout.shape[0], 1), *tSsS.layout.shape[1:]), + stride=( + (tSsS.layout.stride[0], 0), + *tSsS.layout.stride[1:], + ), + ), + ) + tSgS = cute.make_tensor( + tSgS.iterator, + cute.make_layout( + ((tSgS.layout.shape[0], 1), *tSgS.layout.shape[1:]), + stride=( + (tSgS.layout.stride[0], 0), + *tSgS.layout.stride[1:], + ), + ), + ) + return tSsS, tSgS + + def transform_partition( + self, + transform_a_source: tcgen05.OperandSource, + scale_mode: TransformMode, + copy_atom_a_input: cute.CopyAtom, + copy_atom_a_transform: cute.CopyAtom, + sA_input: cute.Tensor, + A_transform: cute.Tensor, + transform_local_tidx: cutlass.Int32, + ) -> tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Partition tensors for transform input and output. + + This method sets up the copy atoms and partitions the shared/tensor memory + for the transformation of tensor A. + + :param transform_a_source: Where the transformed tensor A is stored (TMEM or SMEM) + :type transform_a_source: tcgen05.OperandSource + :param scale_mode: The transform mode (ConvertOnly or ConvertScale) + :type scale_mode: TransformMode + :param copy_atom_a_input: Copy atom for loading A from shared memory + :type copy_atom_a_input: cute.CopyAtom + :param copy_atom_a_transform: Copy atom for storing transformed A + :type copy_atom_a_transform: cute.CopyAtom + :param sA_input: Input tensor A in shared memory + :type sA_input: cute.Tensor + :param A_transform: Transformed tensor A in tensor or shared memory + :type A_transform: cute.Tensor + :param transform_local_tidx: Local thread index for transformation warps + :type transform_local_tidx: cutlass.Int32 + + :return: A tuple containing (src_copy_a, dst_copy_a, tAsA_input, tA_transform) where: + - src_copy_a: Tiled copy for source tensor + - dst_copy_a: Tiled copy for destination tensor + - tAsA_input: Partitioned input tensor A + - tA_transform: Partitioned transformed tensor A + :rtype: tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + if cutlass.const_expr(transform_a_source == tcgen05.OperandSource.TMEM): + if cutlass.const_expr( + cute.size(A_transform, mode=[0, 0]) == 128 + and cute.size(sA_input, mode=[0, 0]) == 64 + ): + tensor_input = cute.make_tensor( + sA_input.iterator, + cute.logical_product( + sA_input.layout, + ((cute.make_layout(2, stride=0), None), None, None, None), + ), + ) + else: + tensor_input = sA_input + reg2tmem_tiled_copy = tcgen05.make_tmem_copy( + copy_atom_a_transform, A_transform[(None, None, None, 0)] + ) + thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice( + transform_local_tidx + ) + partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input) + partitioned_tensor_transform = thr_reg2tmem_tiled_copy.partition_D( + A_transform + ) + src_copy_a = ( + cute.make_tiled_copy_S(copy_atom_a_input, reg2tmem_tiled_copy) + if scale_mode is TransformMode.ConvertScale + else None + ) + dst_copy_a = reg2tmem_tiled_copy + tAsA_input = partitioned_tensor_input + tA_transform = partitioned_tensor_transform + elif cutlass.const_expr(transform_a_source == tcgen05.OperandSource.SMEM): + # Construct tiled_copy satisfying 8 contiguous elts per copy atom + reg2smem_tiled_copy = cute.make_cotiled_copy( + copy_atom_a_transform, + cute.make_layout((128, 8), stride=(8, 1)), + A_transform[(None, None, None, 0)].layout, + ) + thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice( + transform_local_tidx + ) + partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(sA_input) + partitioned_tensor_transform = thr_reg2smem_tiled_copy.partition_D( + A_transform + ) + src_copy_a = ( + cute.make_tiled_copy_S(copy_atom_a_input, reg2smem_tiled_copy) + if scale_mode is TransformMode.ConvertScale + else None + ) + # auto-vec copy is enough for copy from register to shared memory here + dst_copy_a = None + tAsA_input = partitioned_tensor_input + tA_transform = partitioned_tensor_transform + return src_copy_a, dst_copy_a, tAsA_input, tA_transform + + def scale_partition( + self, + src_copy_a: cute.TiledCopy, + tCsS: cute.Tensor, + transform_local_tidx: cutlass.Int32, + mma_dtype: type[cutlass.Numeric], + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]: + """ + Partition the scale tensor for transformation. + + This method prepares the copy atom and partitions the shared memory for the scale tensor. + + :param src_copy_a: Tiled copy for the source tensor + :type src_copy_a: cute.TiledCopy + :param tCsS: Scale tensor in shared memory + :type tCsS: cute.Tensor + :param transform_local_tidx: Local thread index for transformation warps + :type transform_local_tidx: cutlass.Int32 + :param mma_dtype: Data type for the MMA operation + :type mma_dtype: type[cutlass.Numeric] + + :return: A tuple containing (smem_thr_copy_S, tSsS_trans, tSrS) where: + - smem_thr_copy_S: Tiled copy for the scale tensor + - tSsS_trans: Partitioned scale tensor for transformation + - tSrS_copy: Register fragment for the scale tensor + - tSrS: view of scale tensor used for transformation computation + :rtype: tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor] + """ + smem_thr_copy_S = None + tSsS_trans = None + tSrS = None + # Partition scale tensor + smem_thr_copy_S = src_copy_a.get_slice(transform_local_tidx) + tSsS_trans = smem_thr_copy_S.partition_S(tCsS) + # Construct register fragment for scale tensor + tSsS_layout_per_stage = tSsS_trans[(None, None, None, None, 0)].layout + # tSrS for copy + tSrS_copy = cute.make_rmem_tensor( + cute.filter_zeros(tSsS_layout_per_stage).shape, mma_dtype + ) + # tSrS view for transformation computation + tSrS = cute.make_tensor( + tSrS_copy.iterator, + cute.make_layout( + tSsS_layout_per_stage.shape, stride=tSrS_copy.layout.stride + ), + ) + return smem_thr_copy_S, tSsS_trans, tSrS_copy, tSrS + + def get_gmem_layout_scale( + self, scale_shape_mkl: tuple[int, int, int] + ) -> cute.Layout: + """ + Get the layout of the scale tensor in global memory. + + :param scale_shape_mkl: The shape of the scale tensor (M, K, L). + :type scale_shape_mkl: tuple[int, int, int] + + :return: The layout of the scale tensor in global memory. + :rtype: cute.Layout + """ + m, k, l = scale_shape_mkl + shape_scale = ( + (self.scale_granularity_m, cute.ceil_div(m, self.scale_granularity_m)), + (self.scale_granularity_k, cute.ceil_div(k, self.scale_granularity_k)), + ) + if cutlass.const_expr(self.scale_major_mode == tcgen05.OperandMajorMode.MN): + layout_mk = cute.make_layout( + shape_scale, + stride=( + (0, 1), + (0, cute.size(shape_scale[0][1])), + ), + ) + else: + layout_mk = cute.make_layout( + shape_scale, + stride=( + (0, cute.size(shape_scale[1][1])), + (0, 1), + ), + ) + return cute.make_layout( + (*layout_mk.shape, l), + stride=(*layout_mk.stride, cute.cosize(layout_mk)), + ) + + def get_smem_layout_scale(self) -> tuple[cute.ComposedLayout, cute.ComposedLayout]: + """ + Get the layout of the scale tensor in shared memory. + + :return: A tuple containing: + - smem_layout_scale_per_stage: Shared memory layout for scale tensor per stage + - smem_layout_scale: Shared memory layout for scale tensor + :rtype: tuple[cute.ComposedLayout, cute.ComposedLayout] + """ + self.scale_tile_shape = ( + ( + cute.size(self.mma_tiler[0]) // 2 + if self.use_2cta_instrs + else cute.size(self.mma_tiler[0]) + ), + cute.size(self.mma_tiler[2]), + ) + size_mn = self.scale_tile_shape[0] + size_k = self.scale_tile_shape[1] + smem_size_mn = ( + self.scale_granularity_m if self.scale_granularity_m < size_mn else size_mn + ) + smem_size_k = ( + self.scale_granularity_k if self.scale_granularity_k < size_k else size_k + ) + div_mn = cute.ceil_div(size_mn, smem_size_mn) + div_k = cute.ceil_div(size_k, smem_size_k) + smem_atom_shape = ( + (smem_size_mn, div_mn), + (smem_size_k, div_k), + ) + if cutlass.const_expr(self.scale_major_mode == tcgen05.OperandMajorMode.MN): + outer_layout = cute.make_layout( + smem_atom_shape, + stride=( + (0, 1), + (0, div_mn), + ), + ) + else: + outer_layout = cute.make_layout( + smem_atom_shape, + stride=( + (0, div_k), + (0, 1), + ), + ) + # Apply a trivial swizzle to make it a composed layout, which could be used to construct TMA atom + smem_layout_scale_per_stage = cute.make_composed_layout( + cute.make_swizzle(0, 4, 3), 0, outer_layout + ) + assert cute.rank(smem_layout_scale_per_stage) == 2, ( + "Scale layout must be rank 2" + ) + + assert ( + cute.size(self.mma_tiler[0]) + % cute.size(smem_layout_scale_per_stage.outer[0]) + == 0 + ), "smem_layout_scale_per_stage must equal the tile shape." + assert ( + cute.size(self.mma_tiler[2]) + % cute.size(smem_layout_scale_per_stage.outer[1]) + == 0 + ), "smem_layout_scale_per_stage must evenly divide tile k shape." + # Shared memory buffer for scale must be at least 128B to satisfy TMA requirement + assert ( + cute.size_in_bytes(self.a_scale_dtype, smem_layout_scale_per_stage) >= 128 + ), "smem size for scale must be at least 128B" + # Scale layout in smem with multiple stages + smem_layout_scale = cute.append( + smem_layout_scale_per_stage, + cute.make_layout( + (self.num_scale_load2trans_stage), + stride=(cute.cosize(smem_layout_scale_per_stage.outer)), + ), + ) + return smem_layout_scale_per_stage, smem_layout_scale + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """ + Partitions source and destination tensors for a global memory store. + + This method generates a tiled copy for storing results to global memory + and partitions the source (register or shared memory) and destination + (global memory) tensors accordingly. The behavior varies based on whether + TMA store is enabled. + + :param tidx: The thread index in epilogue warp groups. + :type tidx: cutlass.Int32 + :param atom: The copy atom to be used (TMA or universal). + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C. + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler. + :type epi_tile: cute.Tile + :param sC: The shared memory tensor C. + :return: A tuple containing the appropriate copy atom and partitioned + source and destination tensors for the store operation. + :rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if self.use_tma_store: + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Partitions source and destination tensors for a shared memory store. + + This method generates a tiled copy for storing results to shared memory + and partitions the source (register) and destination (shared memory) + tensors accordingly. + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy. + :param tTR_rC: The partitioned accumulator tensor. + :param tidx: The thread index in epilogue warp groups. + :param sC: The shared memory tensor to be copied and partitioned. + :return: A tuple containing the tiled copy for the store operation and + the partitioned source and destination tensors. + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Partitions source and destination tensors for a tensor memory load. + + This method generates a tiled copy for loading accumulators from tensor + memory and partitions the source (tensor memory) and destination + (register) tensors accordingly. + + :param tidx: The thread index in epilogue warp groups. + :param tAcc: The accumulator tensor to be copied and partitioned. + :param gC_mnl: The global tensor C. + :param epi_tile: The epilogue tiler. + :param use_2cta_instrs: Whether use_2cta_instrs is enabled. + :return: A tuple containing the tiled copy for the load operation and + the partitioned source and destination tensors. + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + @staticmethod + def align_up(x: int, align: int) -> int: + """Align x up to the nearest multiple of align.""" + return (x + align - 1) // align * align + + @staticmethod + def _compute_stages_and_tmem_cols( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: tuple[int, int, int], + cta_tile_shape_mnk: tuple[int, int, int], + epi_tile: cute.Tile, + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + transform_a_source: tcgen05.OperandSource, + scale_granularity_m: int, + scale_granularity_k: int, + smem_buffer_align_bytes: int, + use_tma_store: bool, + scale_mode: TransformMode, + ) -> tuple[int, int, int, int, int, int, int]: + """ + Compute pipeline stages and TMEM column allocation configurations. + + This method calculates the number of pipeline stages for different operations + (load2trans, trans2mma, accumulator, etc.) and determines TMEM column allocation + based on available memory resources and tile configuration. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param c_dtype: Data type of operand C. + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param transform_a_source: The source of the transformed A tensor. + :type transform_a_source: tcgen05.OperandSource + :param scale_granularity_m: The granularity of the scale tensor along the M mode. + :type scale_granularity_m: int + :param scale_granularity_k: The granularity of the scale tensor along the K mode. + :type scale_granularity_k: int + :param smem_buffer_align_bytes: The alignment of the shared memory buffer. + :type smem_buffer_align_bytes: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + :param scale_mode: The transform mode. + :type scale_mode: TransformMode + + :return: A tuple containing the number of stages for: + (load2trans, scale_load2trans, transform2mma, accumulator, c, tmem_acc_cols, tmem_a_cols) + :rtype: tuple[int, int, int, int, int, int, int] + - num_load2trans_stage: Stages for load-to-transform A and B tensors pipeline + - num_scale_load2trans_stage: Stages for scale load-to-transform A tensor pipeline + - num_trans2mma_stage: Stages for transform-to-MMA pipeline + - num_acc_stage: Stages for accumulator-to-epilogue pipeline + - num_c_stage: Stages for epilogue-to-output C pipeline + - num_acc_tmem_cols: TMEM columns for accumulator + - num_a_tmem_cols: TMEM columns for transformed A tensor + """ + # Compute tmem columns required for accumulator + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + tCtAcc_stage1 = tiled_mma.make_fragment_C(cute.append(acc_shape, 1)) + num_tmem_acc_col_per_stage = utils.get_num_tmem_alloc_cols(tCtAcc_stage1, True) + # Heuristic to decide the number of stages for accumulator + sm100_tmem_columns = cute.arch.SM100_TMEM_CAPACITY_COLUMNS + accumulator_stage_count = sm100_tmem_columns // num_tmem_acc_col_per_stage + if transform_a_source == tcgen05.OperandSource.TMEM: + if num_tmem_acc_col_per_stage < 128: + accumulator_stage_count = 3 + elif num_tmem_acc_col_per_stage < 256: + accumulator_stage_count = 2 + else: + accumulator_stage_count = 1 + # transformed A in 16bit, thus 1 tmem column could hold 2 elements + num_elts_per_tmem_col = 32 // tiled_mma.op.a_dtype.width + num_tmem_cols_a_per_stage = MixedInputGemmKernel.align_up( + ( + cta_tile_shape_mnk[2] // num_elts_per_tmem_col + if transform_a_source == tcgen05.OperandSource.TMEM + else 0 + ), + 4, + ) + + c_stage_count = 2 if use_tma_store else 0 + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * c_stage_count + + smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + bytes_per_pipeline_stage = 16 + if scale_mode == TransformMode.ConvertOnly: + scale_load2trans_stage_count = 0 + a_scale_bytes_per_stage = 0 + else: + # Ensure we have 2 buffers for scale tiles needed for 1 CTA tile + a_scale_k_mode = max(cta_tile_shape_mnk[2] // scale_granularity_k, 1) + a_scale_m_mode = max(cta_tile_shape_mnk[0] // scale_granularity_m, 1) + scale_load2trans_stage_count = 2 + a_scale_bytes_per_stage = MixedInputGemmKernel.align_up( + cute.size_in_bytes( + tiled_mma.op.a_dtype, + cute.make_layout((a_scale_m_mode, a_scale_k_mode)), + ), + smem_buffer_align_bytes, + ) + a_scale_bytes = ( + a_scale_bytes_per_stage + bytes_per_pipeline_stage + ) * scale_load2trans_stage_count + caveout_smem_bytes = ( + bytes_per_pipeline_stage * accumulator_stage_count + a_scale_bytes + c_bytes + ) + + # Compute transform stages if A is in TMEM + num_tmem_acc_cols = MixedInputGemmKernel.align_up( + accumulator_stage_count * num_tmem_acc_col_per_stage, 4 + ) + + transform2mma_stage_count_a_source_tmem_potential = ( + (sm100_tmem_columns - num_tmem_acc_cols) // num_tmem_cols_a_per_stage + if transform_a_source == tcgen05.OperandSource.TMEM + else -1 + ) + if ( + transform_a_source == tcgen05.OperandSource.TMEM + and transform2mma_stage_count_a_source_tmem_potential <= 0 + ): + raise ValueError("Not enough TMEM capacity for selected tile size") + a_load_bytes_per_stage = MixedInputGemmKernel.align_up( + cute.size_in_bytes( + a_dtype, + cute.make_layout((cta_tile_shape_mnk[0], cta_tile_shape_mnk[2])), + ), + smem_buffer_align_bytes, + ) + b_load_bytes_per_stage = MixedInputGemmKernel.align_up( + cute.size_in_bytes( + b_dtype, + cute.make_layout( + ( + cta_tile_shape_mnk[1] // cute.size(tiled_mma.thr_id), + cta_tile_shape_mnk[2], + ) + ), + ), + smem_buffer_align_bytes, + ) + ab_load_bytes_per_stage = ( + a_load_bytes_per_stage + + b_load_bytes_per_stage + + 2 * bytes_per_pipeline_stage + ) + a_transform_bytes_per_stage = ( + MixedInputGemmKernel.align_up( + cute.size_in_bytes( + tiled_mma.op.a_dtype, + cute.make_layout((cta_tile_shape_mnk[0], cta_tile_shape_mnk[2])), + ), + smem_buffer_align_bytes, + ) + if transform_a_source == tcgen05.OperandSource.SMEM + else 0 + ) + + a_transform_bytes_per_stage = ( + a_transform_bytes_per_stage + bytes_per_pipeline_stage + ) + transform2mma_stage_count_a_source_smem_potential = ( + smem_capacity - caveout_smem_bytes + ) // (ab_load_bytes_per_stage + a_transform_bytes_per_stage) + transform2mma_stage_count = ( + min( + transform2mma_stage_count_a_source_tmem_potential, + transform2mma_stage_count_a_source_smem_potential, + ) + if transform_a_source == tcgen05.OperandSource.TMEM + else transform2mma_stage_count_a_source_smem_potential + ) + load2transform_stage_count = ( + smem_capacity + - caveout_smem_bytes + - (transform2mma_stage_count * a_transform_bytes_per_stage) + ) // ab_load_bytes_per_stage + if ( + load2transform_stage_count < 2 + or transform2mma_stage_count < 2 + or accumulator_stage_count < 1 + ): + raise ValueError("Not enough SMEM or TMEM capacity for selected tile size") + num_tmem_a_cols = transform2mma_stage_count * num_tmem_cols_a_per_stage + # Check if we can increase c_stage_count with leftover smem + if use_tma_store: + c_stage_count += ( + smem_capacity + - load2transform_stage_count * ab_load_bytes_per_stage + - transform2mma_stage_count * a_transform_bytes_per_stage + - scale_load2trans_stage_count * a_scale_bytes_per_stage + - c_bytes + ) // c_bytes_per_stage + + return ( + load2transform_stage_count, + scale_load2trans_stage_count, + transform2mma_stage_count, + accumulator_stage_count, + c_stage_count, + num_tmem_acc_cols, + num_tmem_a_cols, + ) + + @staticmethod + def _compute_smem_layout( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + load2trans_stage_count: int, + trans2mma_stage_count: int, + ) -> tuple[ + cute.ComposedLayout, + cute.ComposedLayout, + cute.ComposedLayout, + ]: + """ + Compute shared memory layouts for tensor A, transformed A and tensor B. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param load2trans_stage_count: Number of stages for load-to-transform pipeline. + :type load2trans_stage_count: int + :param trans2mma_stage_count: Number of stages for transform-to-MMA pipeline. + :type trans2mma_stage_count: int + + :return: A tuple containing: + - smem_layout_a: Shared memory layout for tensor A + - smem_layout_a_transform: Shared memory layout for transformed tensor A + - smem_layout_b: Shared memory layout for tensor B + :rtype: tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] + """ + smem_layout_a = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + load2trans_stage_count, + ) + smem_layout_a_transform = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + tiled_mma.op.a_dtype, + trans2mma_stage_count, + ) + smem_layout_b = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + load2trans_stage_count, + ) + return ( + smem_layout_a, + smem_layout_a_transform, + smem_layout_b, + ) + + @staticmethod + def _get_transform_a_source( + a_major_mode: tcgen05.OperandMajorMode, + ) -> tcgen05.OperandSource: + """ + Determine the operand source for transformed A tensor based on the operand major mode. + """ + if cutlass.const_expr(a_major_mode == tcgen05.OperandMajorMode.K): + return tcgen05.OperandSource.TMEM + else: + return tcgen05.OperandSource.SMEM + + @staticmethod + def _get_tma_atom_kind( + mcast: cutlass.Boolean, + use_2cta_instrs: bool, + is_b: bool, + ) -> Union[ + cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp + ]: + """ + Get the TMA atom kind based on 1) whether it's a multicast operation, + 2) whether 2CTA tcgen05.mma instruction is enabled, and + 3) whether it's a B tensor + """ + # Not using .2CTA instructions for tensor A as the consumer is threads on different CTAs + cta_group = ( + tcgen05.CtaGroup.TWO if (use_2cta_instrs and is_b) else tcgen05.CtaGroup.ONE + ) + if cutlass.const_expr(mcast): + return cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group) + return cpasync.CopyBulkTensorTileG2SOp(cta_group) + + @staticmethod + def _get_copy_atom_a_transform( + mma_dtype: type[cutlass.Numeric], + use_2cta_instrs: bool, + transform_a_source: tcgen05.OperandSource, + a_smem_shape: cute.Shape, + a_dtype: type[cutlass.Numeric], + ) -> cute.CopyAtom: + """ + Determine the copy atom for transformed A tensor based on the operand source and tile size. + """ + if cutlass.const_expr(transform_a_source == tcgen05.OperandSource.TMEM): + if cutlass.const_expr( + cute.size(a_smem_shape[0][0]) == 64 and (not use_2cta_instrs) + ): + copy_op_r2t = tcgen05.St16x256bOp( + tcgen05.Repetition(1), tcgen05.Unpack.NONE + ) + else: + copy_op_r2t = tcgen05.St32x32bOp( + tcgen05.Repetition(8), tcgen05.Unpack.NONE + ) + return cute.make_copy_atom(copy_op_r2t, mma_dtype) + else: + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), a_dtype, num_bits_per_copy=32 + ) + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: + """ + Use persistent tile scheduler to compute the grid size for the output tensor C. + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + def is_valid_scale_granularity( + scale_granularity_m: int, + scale_granularity_k: int, + a_dtype: type[cutlass.Numeric], + k: int, + mma_tiler_k: int, + ) -> bool: + """ + Check if the scale granularity settings are valid for the given data type and problem size. + """ + if a_dtype.width == 8: + # No scale tensor for 8bit data type A + if not (scale_granularity_m == 0 and scale_granularity_k == 0): + return False + elif a_dtype.width == 4: + if scale_granularity_m != 1 or ( + scale_granularity_k == 0 + or k % scale_granularity_k != 0 + or scale_granularity_k % mma_tiler_k != 0 + ): + return False + return True + + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + c_dtype: type[cutlass.Numeric], + scale_dtype: type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mnk: tuple[int, int, int], + use_2cta_instrs: bool, + cluster_shape_mn: tuple[int, int], + scale_granularity_m: int, + scale_granularity_k: int, + ) -> bool: + """ + Check if the tensor alignments are valid for the given problem size and data types. + """ + + def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if not ( + check_contiguous_16B_alignment(a_dtype, a_major == "m", (m, k)) + and check_contiguous_16B_alignment(b_dtype, b_major == "n", (n, k)) + and check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n)) + and ( + scale_granularity_k == 0 + or check_contiguous_16B_alignment( + b_dtype, True, (m, k // scale_granularity_k) + ) + ) + ): + return False + # Check if scale tensor matches the TMA load 128B alignment requirement + cta_tile_shape_mnk = ( + mma_tiler_mnk[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mnk[1], + mma_tiler_mnk[2], + ) + if ( + scale_granularity_m > 0 + and (cta_tile_shape_mnk[0] // cluster_shape_mn[1] // scale_granularity_m) + * (scale_dtype.width // 8) + < 128 + ): + return False + + return True + + def is_valid_epilog_store_option( + m: int, + n: int, + mma_tiler_mn: tuple[int, int], + use_tma_store: bool, + use_2cta_instrs: bool, + ) -> bool: + """ + Check if the epilogue store option is valid for the given problem size. + """ + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + # No OOB tile support when TMA store is disabled + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + return False + return True + + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + ) -> bool: + """ + Check if the MMA tiler and cluster shape are valid for the given problem size. + """ + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + return False + + if (mma_tiler[0] // (2 if use_2cta_instrs else 1)) not in [64, 128]: + return False + return True + + def can_implement( + mnkl: tuple[int, int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + c_dtype: type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + scale_granularity_m: int, + scale_granularity_k: int, + mma_tiler: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + ) -> bool: + """ + Check if the kernel can be implemented for the given tensor shapes and data types. + """ + m, n, k, l = mnkl + + if not MixedInputGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler, cluster_shape_mn, use_2cta_instrs + ): + return False + if not MixedInputGemmKernel.is_valid_scale_granularity( + scale_granularity_m, scale_granularity_k, a_dtype, k, mma_tiler[2] + ): + return False + if not MixedInputGemmKernel.is_valid_tensor_alignment( + m, + n, + k, + a_dtype, + b_dtype, + c_dtype, + b_dtype, + a_major, + b_major, + c_major, + mma_tiler, + use_2cta_instrs, + cluster_shape_mn, + scale_granularity_m, + scale_granularity_k, + ): + return False + if not MixedInputGemmKernel.is_valid_epilog_store_option( + m, n, mma_tiler[:2], use_tma_store, use_2cta_instrs + ): + return False + return True + + +def create_i4_tensor_and_scale( + l: int, + m: int, + k: int, + is_m_major: bool, + dtype: type[cutlass.Numeric], + scale_granularity_m: int, + scale_granularity_k: int, + is_dynamic_layout: bool = True, + init_config: tuple = ( + cutlass_torch.TensorInitType.RANDOM, + cutlass_torch.RandomInitConfig(min_val=-7, max_val=6), + ), + divisibility: int = 16, + transformed_dtype: Optional[type[cutlass.Numeric]] = None, +) -> tuple[ + cute.Tensor, torch.Tensor, torch.Tensor, cute.Tensor, torch.Tensor, torch.Tensor +]: + """ + Create quantized 4-bit tensor and corresponding scale tensor. + """ + lb_4b = -8 if dtype == cutlass.Int4 else 0 + up_4b = 7 if dtype == cutlass.Int4 else 15 + if not ( + init_config[0] == cutlass_torch.TensorInitType.RANDOM + or init_config[0] == cutlass_torch.TensorInitType.SCALAR + ): + raise ValueError( + "Only random and scalar initialization is supported for 4bit data type" + ) + + # Construct reference tensor in f32 + ref_fp32 = cutlass_torch.matrix(l, m, k, is_m_major, cutlass.Float32, *init_config) + # Generate scale data and perform quantization + num_scales = k // scale_granularity_k + ref = ref_fp32.to(dtype=cutlass_torch.dtype(transformed_dtype)).reshape( + m, num_scales, scale_granularity_k, l + ) + # Get elements with maximum absolute value to compute scaling factors + a_max = torch.maximum(ref / up_4b, ref / lb_4b) + a_scales, _ = torch.max(a_max, dim=2, keepdim=True) + a_scale_inv = torch.where(a_scales == 0, 0, 1 / a_scales) + a_quant = ref * a_scale_inv + # Convert values to integer to avoid computation errors + a_quant = a_quant.to(dtype=torch.int32).reshape((m, k, l)).to(dtype=torch.float32) + # Construct A quantized tensor + cute_a_quant_tensor, torch_a_quant_tensor = cutlass_torch.cute_tensor_like( + a_quant, dtype, is_dynamic_layout=is_dynamic_layout, assumed_align=divisibility + ) + # Construct cute scale tensor + a_scales = a_scales.random_(-3, 3).reshape((m, num_scales, l)) + # Scale tensor is always m-major + a_scales = a_scales.permute(2, 1, 0).contiguous().permute(2, 1, 0).to(device="cuda") + cute_scale_tensor = from_dlpack(a_scales, assumed_align=divisibility) + for i, stride in enumerate(a_scales.stride()): + if stride == 1: + leading_dim = i + break + if is_dynamic_layout: + cute_scale_tensor = cute_scale_tensor.mark_layout_dynamic( + leading_dim=leading_dim + ) + + return ( + cute_a_quant_tensor, + torch_a_quant_tensor, + a_quant.to("cpu"), + cute_scale_tensor, + a_scales, + a_scales.to("cpu"), + ) + + +def get_divisibility(contiguous_dim_size: int, upper_bound: int = 128) -> int: + """ + Calculate the largest power of 2 divisibility factor for memory alignment. + """ + # Check the largest power of 2 factor of contiguous_dim_size + for i in range(int(log2(contiguous_dim_size)), 0, -1): + if contiguous_dim_size % (2**i) == 0: + return min(2**i, upper_bound) + return 1 + + +def create_tensor_a( + l: int, + m: int, + k: int, + a_major: str, + a_dtype: type[cutlass.Numeric], + scale_granularity_m: int = 0, + scale_granularity_k: int = 0, + transformed_dtype: Optional[type[cutlass.Numeric]] = None, +) -> tuple[cute.Tensor, cute.Tensor, torch.Tensor, torch.Tensor]: + """ + Create tensor A and scale tensor. + """ + a_scale_tensor = None + a_scale_torch_cpu = None + if a_dtype in (cutlass.Int4,): + ( + a_tensor, + a_torch_gpu, + a_torch_cpu, + a_scale_tensor, + a_scale_torch_gpu, + a_scale_torch_cpu, + ) = create_i4_tensor_and_scale( + l, + m, + k, + a_major == "m", + a_dtype, + scale_granularity_m, + scale_granularity_k, + divisibility=get_divisibility(m if a_major == "m" else k), + transformed_dtype=transformed_dtype, + ) + else: + a_torch_cpu = cutlass_torch.matrix( + l, + m, + k, + a_major == "m", + a_dtype, + ) + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, + a_dtype, + is_dynamic_layout=True, + assumed_align=get_divisibility(m if a_major == "m" else k), + ) + return a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu + + +def create_tensors( + l: int, + m: int, + n: int, + k: int, + a_major: str, + b_major: str, + c_major: str, + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + c_dtype: type[cutlass.Numeric], + scale_granularity_m: int = 0, + scale_granularity_k: int = 0, +) -> tuple: + """ + Create all input and output tensors for the mixed-input GEMM. + """ + torch.manual_seed(2025) + + a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu = create_tensor_a( + l, m, k, a_major, a_dtype, scale_granularity_m, scale_granularity_k, b_dtype + ) + + b_torch_cpu = cutlass_torch.matrix( + l, + n, + k, + b_major == "n", + b_dtype, + cutlass_torch.TensorInitType.RANDOM, + cutlass_torch.RandomInitConfig(min_val=-10, max_val=10), + ) + c_torch_cpu = cutlass_torch.matrix( + l, + m, + n, + c_major == "m", + c_dtype, + ) + + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, + b_dtype, + is_dynamic_layout=True, + assumed_align=get_divisibility(n if b_major == "n" else k), + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, + c_dtype, + is_dynamic_layout=True, + assumed_align=get_divisibility(m if c_major == "m" else n), + ) + c_tensor = c_tensor.mark_compact_shape_dynamic( + mode=(0 if c_major == "m" else 1), + stride_order=(2, 1, 0) if c_major == "m" else (2, 0, 1), + divisibility=get_divisibility(m if c_major == "m" else n), + ) + + return ( + a_tensor, + a_scale_tensor, + b_tensor, + c_tensor, + a_torch_cpu, + a_scale_torch_cpu, + b_torch_cpu, + c_torch_gpu, + ) + + +def compare( + a_torch_cpu: torch.Tensor, + b_torch_cpu: torch.Tensor, + a_scale_torch_cpu: Optional[torch.Tensor], + c_torch_gpu: torch.Tensor, + c_dtype: type[cutlass.Numeric], + tolerance: float, +) -> None: + """ + Compare kernel result with reference computation. + """ + kernel_result = c_torch_gpu.cpu() + # Compute reference result + if a_scale_torch_cpu is not None: + scale_shape = a_scale_torch_cpu.shape + a_shape = a_torch_cpu.shape + a_scale_torch_cpu = a_scale_torch_cpu.to(dtype=torch.float32).reshape( + scale_shape[0], scale_shape[1], 1, scale_shape[2] + ) + a_torch_cpu = a_torch_cpu.to(dtype=torch.float32).reshape( + a_torch_cpu.shape[0], scale_shape[1], -1, a_torch_cpu.shape[2] + ) + a_dequant = a_torch_cpu * a_scale_torch_cpu + ref = torch.einsum( + "mkl,nkl->mnl", + a_dequant.reshape(a_shape), + b_torch_cpu.to(dtype=torch.float32), + ) + else: + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch_cpu.to(dtype=torch.float32), + b_torch_cpu.to(dtype=torch.float32), + ) + # Convert ref to c_dtype + _, ref_torch_gpu = cutlass_torch.cute_tensor_like( + ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + ref_result = ref_torch_gpu.cpu() + + # Assert close results + torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05) + + +def run( + mnkl: tuple[int, int, int, int], + scale_granularity_m: int, + scale_granularity_k: int, + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + c_dtype: type[cutlass.Numeric], + acc_dtype: type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +) -> None: + """ + Run the mixed-input GEMM kernel with specified parameters. + + This function creates tensors, validates parameters, executes the kernel, + optionally compares results with a reference implementation and reports + kernel execution time. + """ + m, n, k, l = mnkl + + if not torch.cuda.is_available(): + raise ValueError("CUDA is not available") + + # Check if given configuration is supported + if not MixedInputGemmKernel.can_implement( + mnkl, + a_dtype, + b_dtype, + c_dtype, + a_major, + b_major, + c_major, + scale_granularity_m, + scale_granularity_k, + mma_tiler_mnk, + cluster_shape_mn, + use_2cta_instrs, + use_tma_store, + ): + raise ValueError("GEMM configuration not supported") + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + ( + a_tensor, + a_scale_tensor, + b_tensor, + c_tensor, + a_torch_cpu, + a_scale_torch_cpu, + b_torch_cpu, + c_torch_gpu, + ) = create_tensors( + l, + m, + n, + k, + a_major, + b_major, + c_major, + a_dtype, + b_dtype, + c_dtype, + scale_granularity_m, + scale_granularity_k, + ) + + mixed_input_gemm = MixedInputGemmKernel( + scale_granularity_m, + scale_granularity_k, + acc_dtype, + use_2cta_instrs, + mma_tiler_mnk, + cluster_shape_mn, + use_tma_store, + ) + + max_active_clusters = utils.HardwareInfo().get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1], + ) + compiled_kernel = cute.compile( + mixed_input_gemm, + a_tensor, + a_scale_tensor, + b_tensor, + c_tensor, + max_active_clusters, + current_stream, + ) + + if not skip_ref_check: + compiled_kernel( + a_tensor, + a_scale_tensor, + b_tensor, + c_tensor, + current_stream, + ) + compare( + a_torch_cpu, b_torch_cpu, a_scale_torch_cpu, c_torch_gpu, c_dtype, tolerance + ) + + # Early return if no performance measurement is needed + if iterations <= 0: + return + + def generate_tensors(): + a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu = create_tensor_a( + l, m, k, a_major, a_dtype, scale_granularity_m, scale_granularity_k, b_dtype + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, + b_dtype, + is_dynamic_layout=True, + assumed_align=get_divisibility(n if b_major == "n" else k), + ) + c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major == "m", c_dtype) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, + c_dtype, + is_dynamic_layout=True, + assumed_align=get_divisibility(m if c_major == "m" else n), + ) + c_tensor = c_tensor.mark_compact_shape_dynamic( + mode=(0 if c_major == "m" else 1), + stride_order=(2, 1, 0) if c_major == "m" else (2, 0, 1), + divisibility=get_divisibility(m if c_major == "m" else n), + ) + return testing.JitArguments( + a_tensor, a_scale_tensor, b_tensor, c_tensor, current_stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_gpu.numel() * c_torch_gpu.element_size() + + a_scale_torch_cpu.numel() * a_scale_torch_cpu.element_size() + if a_scale_torch_cpu is not None + else 0 + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_kernel, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--mnkl", type=parse_comma_separated_ints, default=(128, 128, 128, 1) + ) + parser.add_argument( + "--mma_tiler_mnk", type=parse_comma_separated_ints, default=(128, 128, 128) + ) + parser.add_argument( + "--cluster_shape_mn", type=parse_comma_separated_ints, default=(1, 1) + ) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument( + "--a_dtype", + type=cutlass.dtype, + default=cutlass.Int4, + choices=[cutlass.Int8, cutlass.Uint8, cutlass.Int4], + ) + parser.add_argument( + "--b_dtype", + type=cutlass.dtype, + default=cutlass.BFloat16, + choices=[cutlass.BFloat16, cutlass.Float16], + ) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="m") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--scale_granularity_m", + type=int, + default=1, + help="Scale granularity along M dimension.", + ) + parser.add_argument( + "--scale_granularity_k", + type=int, + default=128, + help="Scale granularity along K dimension.", + ) + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + args = parser.parse_args() + + run( + args.mnkl, + args.scale_granularity_m, + args.scale_granularity_k, + args.a_dtype, + args.b_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mnk, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/mla.py b/examples/python/CuTeDSL/blackwell/mla.py new file mode 100644 index 00000000..4bdbcac5 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/mla.py @@ -0,0 +1,5193 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import math +from typing import Type, Tuple, Optional, Callable +from types import SimpleNamespace +from functools import partial + +import torch +import torch.nn.functional as F +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.cute.nvgpu.cpasync as cpasync +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +""" +A Multi-Head Latent Attention (MLA) example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell +SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T +matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. +The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting +functionality to minimize latency when processing long KV sequences. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing + +To run this example: + +.. code-block:: bash + + python examples/blackwell/mla.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len 1024 \ + --in_dtype Float8E4M3FN --out_dtype Float16 \ + --acc_dtype Float32 --lse_dtype Float32 \ + --use_page_table --is_var_seq --is_var_split_kv \ + --is_persistent + +The above example runs Multi-Head Latent Attention (MLA) with the following configuration: +- Batch size: 4 +- Sequence length: 1024 +- Latent dimension: 512 +- RoPE dimension: 64 +- Number of heads: 128 +- Data types: Float8E4M3FN (input), Float16 (output), Float32 (accumulation and LSE) + +It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences +and variable split KV processing with persistent scheduling. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/mla.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len 1024 \ + --in_dtype Float8E4M3FN --out_dtype Float16 \ + --acc_dtype Float32 --lse_dtype Float32 \ + --use_page_table --is_var_seq --is_var_split_kv \ + --is_persistent --warmup_iterations 3 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Data type requirements: + - Input/output: Float8E4M3FN or Float16 + - Accumulation and LSE: Float32 +* Fixed architecture parameters: + - Number of attention heads: 128 + - Latent dimension: 512 + - RoPE dimension: 64 +* Input query modes should be (NumHeads, LatentDim/RopeDim, BatchSize) +* Input kv latent/rope modes should be (SeqLen, LatentDim/RopeDim, BatchSize) +* Query sequence length must be 1 +* Only supports 2-CTA instructions +* Variable sequence length requires page table storage enabled +""" + + +class MLAStaticTileSchedulerParams: + def __init__( + self, + is_persistent: bool, + problem_shape_b: cute.Int32, + cluster_shape_mnk: cute.Shape, + split_kv: cutlass.Int32, + *, + loc=None, + ip=None, + ): + """The static tile scheduler parameters prepared for MLA static tile scheduler. + + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param problem_shape_b: The shape of the problem + :type problem_shape_b: cute.Int32 + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param split_kv: The scalar factor for split KV + """ + self.is_persistent = is_persistent + self.problem_shape_b = problem_shape_b + self.cluster_shape_mnk = cluster_shape_mnk + self.split_kv = split_kv + self.loc = loc + self.ip = ip + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.problem_shape_b) + values += cutlass.extract_mlir_values(self.split_kv) + return values + + def __new_from_mlir_values__(self, values): + problem_shape_b = cutlass.new_from_mlir_values( + self.problem_shape_b, (values[0],) + ) + split_kv = cutlass.new_from_mlir_values(self.split_kv, (values[1],)) + return MLAStaticTileSchedulerParams( + self.is_persistent, + problem_shape_b, + self.cluster_shape_mnk, + split_kv, + loc=self.loc, + ) + + +def create_mla_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_b: cute.Int32, + cluster_shape_mnk: cute.Shape, + split_kv: cutlass.Int32, +) -> MLAStaticTileSchedulerParams: + return MLAStaticTileSchedulerParams( + is_persistent, problem_shape_b, cluster_shape_mnk, split_kv + ) + + +class MLAStaticTileScheduler: + def __init__( + self, + params: MLAStaticTileSchedulerParams, + current_work_linear_idx: cutlass.Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + is_valid: bool = True, + loc=None, + ip=None, + ): + """The static tile scheduler for MLA split kv kernel. + Based on `is_persistent`, it provides 2 modes for use: + - Persistent mode: Launch fixed blocks and reschedule the data blocks. + - Non-persistent mode: Launch dynamic blocks and exit when the current work is done. + + :param params: The static tile scheduler parameters + :type params: MLAStaticTileSchedulerParams + :param current_work_linear_idx: The linear index of the current work + :type current_work_linear_idx: cutlass.Int32 + :param blk_coord: The coordinate of the current work + :type blk_coord: cute.Coord + :param grid_shape: The shape of the grid + :type grid_shape: cute.Shape + :param is_valid: Whether the current work is valid + :type is_valid: bool + """ + self.params = params + self.blk_coord = blk_coord + self.grid_shape = grid_shape + self.current_work_linear_idx = current_work_linear_idx + if params.is_persistent: + self.persistent_blk_layout = cute.make_layout( + ( + params.cluster_shape_mnk[0], + 1, + params.problem_shape_b, + params.split_kv, + ), + loc=loc, + ip=ip, + ) + self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip) + # Used for persistent scheduling + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + else: + self.is_valid = is_valid + self.loc = loc + self.ip = ip + + @staticmethod + def get_grid_shape( + params: MLAStaticTileSchedulerParams, + max_active_clusters: int, + *, + loc=None, + ip=None, + ) -> cute.Shape: + # called by host + grid_shape = ( + params.cluster_shape_mnk[0], + params.problem_shape_b, + params.split_kv, + ) + if params.is_persistent: + return ( + cutlass.min( + max_active_clusters * cute.size(params.cluster_shape_mnk), + cute.size(grid_shape, loc=loc, ip=ip), + ), + 1, + 1, + ) + else: + return grid_shape + + def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo: + is_valid = ( + self.current_work_linear_idx < self.num_blocks + if self.params.is_persistent + else self.is_valid + ) + + if self.params.is_persistent: + blk_coord = self.persistent_blk_layout.get_hier_coord( + self.current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = (self.blk_coord[0], 0, self.blk_coord[1], self.blk_coord[2]) + + return utils.WorkTileInfo(blk_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + if self.params.is_persistent: + self.current_work_linear_idx += advance_count * self.num_persistent_sm + else: + self.is_valid = False + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.params) + values.extend(cutlass.extract_mlir_values(self.current_work_linear_idx)) + values.extend(cutlass.extract_mlir_values(self.blk_coord)) + values.extend(cutlass.extract_mlir_values(self.grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 9 + new_params = cutlass.new_from_mlir_values(self.params, values[0:2]) + new_current_work_linear_idx = cutlass.new_from_mlir_values( + self.current_work_linear_idx, [values[2]] + ) + new_blk_coord = cutlass.new_from_mlir_values(self.blk_coord, values[3:6]) + new_grid_shape = cutlass.new_from_mlir_values(self.grid_shape, values[6:]) + return MLAStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_mla_static_tile_scheduler( + params: MLAStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> MLAStaticTileScheduler: + return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) + + +LOG2_E = 1.4426950408889634074 +# avoid register indexing on array. +MAX_SPLITS = 256 + + +class BlackwellMultiHeadLatentAttentionForward: + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + max_active_clusters: int, + is_persistent: bool, + is_cpasync: bool, + use_page_table: bool, + is_var_seq: bool, + is_var_split_kv: bool, + ): + """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. + + :param acc_dtype: Data type for accumulation S and O + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Data type for output LSE + :type lse_dtype: Type[cutlass.Numeric] + :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S + :type mma_s_tiler: Tuple[int, int] + :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P + :type mma_p_tiler: Tuple[int, int] + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: int + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param is_cpasync: Whether to use CP async mode + :type is_cpasync: bool + :param use_page_table: Whether to use page table + :type use_page_table: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split KV + :type is_var_split_kv: bool + """ + + self.latent_dim = 512 + self.rope_dim = 64 + self.acc_dtype = acc_dtype + self.lse_dtype = lse_dtype + self.mma_qk_tiler_mn = mma_qk_tiler_mn + self.mma_pv_tiler_mn = mma_pv_tiler_mn + self.max_active_clusters = max_active_clusters + self.is_persistent = is_persistent + self.is_cpasync = is_cpasync + self.use_page_table = use_page_table + self.is_var_seq = is_var_seq + self.is_var_split_kv = is_var_split_kv + self.cluster_shape_mnk = (2, 1, 1) + self.use_2cta_instrs = True + # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), + # while warps 2-3 handle accumulation for second half [n/2, n) + self.warps_in_n = 2 + self.num_compute_warps = 4 + self.threads_per_warp = 32 + self.num_load_warps = 2 if self.is_cpasync else 1 + mma_qk_tiler_k = self.rope_dim if self.is_cpasync else self.rope_dim * 2 + self.mma_qk_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + mma_qk_tiler_k, + ) + self.mma_pv_tiler = ( + self.mma_pv_tiler_mn[0], + self.mma_pv_tiler_mn[1], + self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], + ) + self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] + self.iterations_qk_rope = mma_qk_tiler_k // self.mma_qk_tiler[2] + self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope + self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] + self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] + + # Set specialized warp ids + self.compute_warp_ids = (0, 1, 2, 3) + self.correction_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + if self.is_cpasync: + self.load_cp_async_warp_ids = (9, 10) + self.load_pt_warp_id = 11 + self.threads_per_cta = self.threads_per_warp * len( + ( + self.mma_warp_id, + *self.load_cp_async_warp_ids, + self.load_pt_warp_id, + *self.compute_warp_ids, + *self.correction_warp_ids, + ) + ) + else: + self.load_tma_warp_id = 9 + self.empty_warp_ids = (10, 11) + self.threads_per_cta = self.threads_per_warp * len( + ( + self.mma_warp_id, + self.load_tma_warp_id, + *self.compute_warp_ids, + *self.correction_warp_ids, + *self.empty_warp_ids, + ) + ) + + # register settings + self.softmax_reg_num = 192 + self.correction_reg_num = 192 + self.other_reg_num = 112 + # Named barriers + self.tmem_ptr_sync_bar = pipeline.NamedBarrier( + barrier_id=1, + num_threads=( + self.threads_per_warp + + self.threads_per_warp * self.num_compute_warps * 2 + ), + ) + self.softmax_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + + def _setup_attributes(self): + """Set up configurations and parameters for the MLA kernel operation. + + This method initializes and configures various attributes required for the + execution of the multi-head latent attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.load_q_stage = self.iterations_qk + self.load_kv_stage = (24 if self.is_cpasync else 12) // ( + self.k_dtype.width // 8 + ) + self.mma_s_stage = 2 + self.p_mma_stage = 2 + self.p_cor_stage = 2 + self.mma_o_stage = 1 + self.load_pt_stage = self.load_kv_stage if self.is_cpasync else 1 + + self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n + self.correction_factor_offset = ( + self.tmem_o_offset + self.latent_dim // self.warps_in_n + ) + + @cute.jit + def __call__( + self, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + workspace: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: Optional[cute.Tensor], + block_split_kvs: Optional[cute.Tensor], + softmax_scale: cutlass.Float32, + output_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + """Execute the Multi-Head Latent Attention operation on the provided tensors. + + The method handles: + 1. Initialization of workspace for temporary split KV buffers + 2. Validation of tensor data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters + + :param q_latent: The query tensor with shape [num_head, latent_dim, batch_size] + :type q_latent: cute.Tensor + :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, batch_size] + :type q_rope: cute.Tensor + :param c_latent: The key tensor with shape [seq_len, latent_dim, batch_size] + :type c_latent: cute.Tensor + :param c_rope: The key RoPE tensor with shape [seq_len, rope_dim, batch_size] + :type c_rope: cute.Tensor + :param page_table: The page table tensor with shape [page_count, batch_size] + :type page_table: cute.Tensor + :param o: The output tensor with shape [num_head, latent_dim, batch_size] + :type o: cute.Tensor + :param lse: The LSE tensor with shape [num_head, batch_size] + :type lse: cute.Tensor + :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse + :type workspace: cute.Tensor + :param split_kv: The scalar factor for split KV + :type split_kv: cutlass.Int32 + :param cache_seqs: The cache sequences tensor with shape [batch_size] + :type cache_seqs: cute.Tensor + :param block_split_kvs: The block split KV tensor with shape [batch_size] + :type block_split_kvs: cute.Tensor + :param softmax_scale: The scale factor for softmax + :type softmax_scale: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param stream: The CUDA stream to execute the kernel on + :type stream: cuda.CUstream + + :raises TypeError: If tensor data types don't match or aren't supported + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q_latent.element_type + self.k_dtype = c_latent.element_type + self.v_dtype = c_latent.element_type + self.o_dtype = o.element_type + + # check type consistency + if cutlass.const_expr( + self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype + ): + raise TypeError( + f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" + ) + # check leading dimensions of input/output + if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1): + raise ValueError("q_latent and q_rope must have leading dimension 1") + if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1): + raise ValueError("c_latent and c_rope must have leading dimension 1") + if cutlass.const_expr(o.stride[1] != 1): + raise ValueError("o must have leading dimension 1") + if cutlass.const_expr(lse.stride[0] != 1): + raise ValueError("lse must have leading dimension 0") + + acc_o, acc_lse = self.initialize_workspace( + q_latent.shape[0], + q_latent.shape[1], + q_latent.shape[2], + split_kv, + self.acc_dtype, + workspace, + ) + + c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) + c_latent_transpose = cute.make_tensor( + c_latent.iterator, c_latent_tranpose_layout + ) + + self.q_major_mode = tcgen05.OperandMajorMode.K + self.k_major_mode = tcgen05.OperandMajorMode.K + self.v_major_mode = tcgen05.OperandMajorMode.MN + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from smem & k-major + p_major_mode = tcgen05.OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.acc_dtype, + cta_group, + self.mma_qk_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.acc_dtype, + cta_group, + self.mma_pv_tiler[:2], + ) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.mma_pv_tiler[:2] + + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_tiler, + self.q_dtype, + self.load_q_stage, + ) + kc_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.mma_qk_tiler, + self.k_dtype, + self.load_kv_stage, + ) + p_smem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.mma_pv_tiler, + self.q_dtype, + (self.iterations_pv_k * self.p_mma_stage), + ) + p_smem_layout_staged = cute.logical_divide( + p_smem_layout_staged, (None, None, None, self.iterations_pv_k) + ) + vc_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.mma_pv_tiler, + self.v_dtype, + self.load_kv_stage, + ) + if cutlass.const_expr(not self.is_cpasync): + # TMA load for Q latent and rope + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_latent, + q_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_rope, + q_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + # TMA load for c latent and k rope + kc_smem_layout = cute.select(kc_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_c_latent, tma_tensor_c_latent = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + c_latent, + kc_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + tma_atom_c_rope, tma_tensor_c_rope = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + c_rope, + kc_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + # TMA load for c latent transpose + vc_smem_layout = cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( + cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + c_latent_transpose, + vc_smem_layout, + self.mma_pv_tiler, + pv_tiled_mma, + cta_layout_vmnk.shape, + ) + ) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) * cute.size( + qk_tiled_mma.thr_id.shape + ) + kc_copy_size = cute.size_in_bytes(self.k_dtype, kc_smem_layout) * cute.size( + qk_tiled_mma.thr_id.shape + ) + vc_copy_size = cute.size_in_bytes(self.v_dtype, vc_smem_layout) * cute.size( + pv_tiled_mma.thr_id.shape + ) + assert kc_copy_size == vc_copy_size, ( + "kc_copy_size and vc_copy_size must be the same" + ) + + self.tma_copy_q_bytes = q_copy_size + self.tma_copy_kc_bytes = kc_copy_size + else: + self.tma_copy_q_bytes = 0 + self.tma_copy_kc_bytes = 0 + + tile_sched_params, grid = self._compute_grid( + o, + split_kv, + self.cluster_shape_mnk, + self.max_active_clusters, + self.is_persistent, + ) + + @cute.struct + class SplitKVKernelSharedStorage: + # Pipeline barriers + load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_kv_stage * 2 + ] + mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] + p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] + mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] + load_pt_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_pt_stage * 2 + ] + + # Smem tensors + softmax_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + epilogue_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + + smem_page_table: cute.struct.MemRange[ + cutlass.Int32, self.load_pt_stage * self.mma_qk_tiler[1] + ] + smem_q: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)], + 1024, + ] + smem_kc: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(kc_smem_layout_staged)], + 1024, + ] + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], + 1024, + ] + # Tmem dealloc cluster barrier + tmem_dealloc_mbar_ptr: cutlass.Int64 + + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + + softmax_scale_log2 = softmax_scale * LOG2_E + # Launch the kernel synchronously + if cutlass.const_expr(self.is_cpasync): + self.split_kv_kernel( + qk_tiled_mma, + pv_tiled_mma, + None, + q_latent, + None, + q_rope, + None, + c_latent, + None, + c_rope, + None, + c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale_log2, + output_scale, + q_smem_layout_staged, + kc_smem_layout_staged, + p_smem_layout_staged, + vc_smem_layout_staged, + cta_layout_vmnk, + tile_sched_params, + SplitKVKernelSharedStorage, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=SplitKVKernelSharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + else: + self.split_kv_kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q_latent, + tma_tensor_q_latent, + tma_atom_q_rope, + tma_tensor_q_rope, + tma_atom_c_latent, + tma_tensor_c_latent, + tma_atom_c_rope, + tma_tensor_c_rope, + tma_atom_c_latent_transpose, + tma_tensor_c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale_log2, + output_scale, + q_smem_layout_staged, + kc_smem_layout_staged, + p_smem_layout_staged, + vc_smem_layout_staged, + cta_layout_vmnk, + tile_sched_params, + SplitKVKernelSharedStorage, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=SplitKVKernelSharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + if cutlass.const_expr(acc_o is not None): + self.reduction_kernel( + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + ).launch( + grid=(q_latent.shape[0], 1, q_latent.shape[2]), + block=[self.threads_per_warp * self.num_compute_warps, 1, 1], + smem=MAX_SPLITS * self.acc_dtype.width // 8, + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def split_kv_kernel( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tma_atom_q_latent: Optional[cute.CopyAtom], + mQL: cute.Tensor, + tma_atom_q_rope: Optional[cute.CopyAtom], + mQR: cute.Tensor, + tma_atom_c_latent: Optional[cute.CopyAtom], + mCL: cute.Tensor, + tma_atom_c_rope: Optional[cute.CopyAtom], + mKR: cute.Tensor, + tma_atom_c_latent_transpose: Optional[cute.CopyAtom], + mCLT: cute.Tensor, + mPT: cute.Tensor, + mO: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + mAccO: Optional[cute.Tensor], + mAccLSE: Optional[cute.Tensor], + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + output_scale: cutlass.Float32, + q_smem_layout_staged: cute.ComposedLayout, + kc_smem_layout_staged: cute.ComposedLayout, + p_smem_layout_staged: cute.ComposedLayout, + vc_smem_layout_staged: cute.ComposedLayout, + cta_layout_vmnk: cute.Layout, + tile_sched_params: MLAStaticTileSchedulerParams, + SharedStorage: cutlass.Constexpr, + ): + """The device split_kv kernel implementation of the Multi-Head Latent Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: + 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results + to global memory + + The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. + When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results + that will later be combined by a reduction kernel. + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases. + + :param tiled_mma_qk: Tiled MMA for Q*K^T + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: Tiled MMA for P*V + :type tiled_mma_pv: cute.TiledMma + :param tma_atom_q_latent: TMA copy atom for query latent tensor + :type tma_atom_q_latent: cute.CopyAtom + :param mQL: query latent tensor + :type mQL: cute.Tensor + :param tma_atom_q_rope: TMA copy atom for query rope tensor + :type tma_atom_q_rope: cute.CopyAtom + :param mKR: Compressed rope tensor + :type mKR: cute.Tensor + :param tma_atom_c_latent: TMA copy atom for c latent tensor + :type tma_atom_c_latent: cute.CopyAtom + :param mCL: Compressed latent tensor + :type mCL: cute.Tensor + :param tma_atom_c_rope: TMA copy atom for c rope tensor + :type tma_atom_c_rope: cute.CopyAtom + :param mCLT: Compressed latent transpose tensor + :type mCLT: cute.Tensor + :param mPT: Page table tensor + :type mPT: cute.Tensor + :param mO: Output tensor + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor + :type mLSE: cute.Tensor + :param mAccO: Intermediate accumulator output tensor + :type mAccO: cute.Tensor + :param mAccLSE: Intermediate accumulator log-sum-exp tensor + :type mAccLSE: cute.Tensor + :param split_kv: The split_kv parameter + :type split_kv: cutlass.Int32 + :param cache_seqs: The variable sequence length tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: The per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param softmax_scale_log2: The log2 scale factor for softmax + :type softmax_scale_log2: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param q_smem_layout_staged: Shared memory layout for query tensor + :type q_smem_layout_staged: cute.ComposedLayout + :param kc_smem_layout_staged: Shared memory layout for key tensor + :type kc_smem_layout_staged: cute.ComposedLayout + :param p_smem_layout_staged: Shared memory layout for probability matrix + :type p_smem_layout_staged: cute.ComposedLayout + :param vc_smem_layout_staged: Shared memory layout for value tensor + :type vc_smem_layout_staged: cute.ComposedLayout + :param cta_layout_vmnk: Layout for compute threads + :type cta_layout_vmnk: cute.Layout + :param tile_sched_params: Scheduling parameters for work distribution + :type tile_sched_params: MLAStaticTileSchedulerParams + :param SharedStorage: Shared storage for the kernel + :type SharedStorage: cutlass.Constexpr + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # Coords inside cluster + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + + # Prefetch tma descriptor + if cutlass.const_expr(not self.is_cpasync): + if warp_idx == self.mma_warp_id: + cpasync.prefetch_descriptor(tma_atom_q_latent) + cpasync.prefetch_descriptor(tma_atom_q_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent) + cpasync.prefetch_descriptor(tma_atom_c_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Tensor memory dealloc barrier init + if warp_idx == self.mma_warp_id: + num_tmem_dealloc_threads = self.threads_per_warp * self.num_compute_warps + with cute.arch.elect_one(): + cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads) + cute.arch.mbarrier_init_fence() + + load_q_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_q_stage, + self.tma_copy_q_bytes, + self.is_cpasync, + ) + load_kv_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_kv_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_kv_stage, + self.tma_copy_kc_bytes, + self.is_cpasync, + ) + mma_s_pipeline = self.make_and_init_mma_s_pipeline( + storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_mma_pipeline = self.make_and_init_p_mma_pipeline( + storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_cor_pipeline = self.make_and_init_p_cor_pipeline( + storage.p_cor_mbar_ptr.data_ptr() + ) + mma_o_pipeline = self.make_and_init_mma_o_pipeline( + storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + if cutlass.const_expr(self.is_cpasync): + load_pt_pipeline = self.make_and_init_load_pt_pipeline( + storage.load_pt_mbar_ptr.data_ptr() + ) + + # Cluster arrive after barrier init + if cutlass.const_expr(cute.size(self.cluster_shape_mnk) > 1): + cute.arch.cluster_arrive_relaxed() + + # Generate smem tensor Q/KC/VC/exchange + # (MMA, MMA_H, MMA_R, PIPE) + sQ = storage.smem_q.get_tensor( + q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_R, PIPE) + sKC = storage.smem_kc.get_tensor( + kc_smem_layout_staged.outer, swizzle=kc_smem_layout_staged.inner + ) + # (MMA, MMA_D, MMA_K, PIPE) + # reuse smem + sVC_ptr = cute.recast_ptr(sKC.iterator, vc_smem_layout_staged.inner) + sVC = cute.make_tensor(sVC_ptr, vc_smem_layout_staged.outer) + # (MMA, MMA_H, MMA_K) + sP = storage.smem_p.get_tensor( + p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner + ) + # (compute_threads,) + softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + + # + # Cluster wait before tensor memory alloc + # + if cutlass.const_expr(cute.size(self.cluster_shape_mnk) > 1): + cute.arch.cluster_wait() + else: + pipeline.sync(barrier_id=4) + + # /////////////////////////////////////////////////////////////////////////////// + # Load warps, including page table and data tensors + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.is_cpasync): + sPT = storage.smem_page_table.get_tensor( + cute.make_layout((self.mma_qk_tiler[1], self.load_pt_stage)) + ) + # Load page table when isasync is true + if warp_idx == self.load_pt_warp_id: + cute.arch.warpgroup_reg_dealloc(self.other_reg_num) + load_pt_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_pt_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + load_pt_common_params = SimpleNamespace( + blk_coord=blk_coord, + load_pt_pipeline=load_pt_pipeline, + mPT=mPT, + sPT=sPT, + tidx=tidx, + page_size=mCL.shape[0], + ) + load_pt_producer_state = self.load_page_table( + load_pt_common_params, + k_index, + k_tile_count, + load_pt_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + load_pt_pipeline.producer_tail(load_pt_producer_state) + + if ( + warp_idx == self.load_cp_async_warp_ids[0] + or warp_idx == self.load_cp_async_warp_ids[1] + ): + cute.arch.warpgroup_reg_dealloc(self.other_reg_num) + load_pt_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_pt_stage + ) + load_pt_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_pt_stage + ) + load_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_q_stage + ) + load_kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_kv_stage + ) + load_kv_commit_state = load_kv_producer_state.clone() + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + load_cpasync_common_params = SimpleNamespace( + blk_coord=blk_coord, + load_pt_pipeline=load_pt_pipeline, + load_q_pipeline=load_q_pipeline, + load_kv_pipeline=load_kv_pipeline, + sPT=sPT, + tidx=tidx, + page_size=mCL.shape[0], + ) + load_cpasync_qk_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + mQL=mQL, + mQR=mQR, + mCL=mCL, + mKR=mKR, + sQ=sQ, + sKC=sKC, + ) + load_cpasync_v_params = SimpleNamespace( + tiled_mma_pv=tiled_mma_pv, + mCLT=mCLT, + sVC=sVC, + ) + ( + load_pt_consumer_state, + load_pt_release_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + ) = self.load_cpasync( + load_cpasync_common_params, + load_cpasync_qk_params, + load_cpasync_v_params, + k_index, + k_tile_count, + load_pt_consumer_state, + load_pt_release_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + load_q_pipeline.producer_tail(load_q_producer_state) + load_kv_pipeline.producer_tail(load_kv_producer_state) + else: + if ( + warp_idx >= self.empty_warp_ids[0] + and warp_idx <= self.empty_warp_ids[-1] + ): + cute.arch.warpgroup_reg_dealloc(self.other_reg_num) + if warp_idx == self.load_tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.other_reg_num) + load_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_q_stage + ) + load_kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_kv_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + # Construct fixed common/tma_qk/tma_pv params for load_tma + tma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_kv_pipeline=load_kv_pipeline, + mPT=mPT, + ) + tma_qk_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + tma_atom_q_latent=tma_atom_q_latent, + tma_atom_q_rope=tma_atom_q_rope, + tma_atom_c_latent=tma_atom_c_latent, + tma_atom_c_rope=tma_atom_c_rope, + mQL=mQL, + mQR=mQR, + mCL=mCL, + mKR=mKR, + sQ=sQ, + sKC=sKC, + ) + tma_pv_params = SimpleNamespace( + tiled_mma_pv=tiled_mma_pv, + tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, + mCL=mCL, + mKR=mKR, + mCLT=mCLT, + sVC=sVC, + ) + # Load tma + load_q_producer_state, load_kv_producer_state = self.load_tma( + tma_common_params, + tma_qk_params, + tma_pv_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_q_pipeline.producer_tail(load_q_producer_state) + load_kv_pipeline.producer_tail(load_kv_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA warp + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.other_reg_num) + # Alloc tensor memory buffer + cute.arch.alloc_tmem( + cute.arch.SM100_TMEM_CAPACITY_COLUMNS, + tmem_holding_buf, + is_two_cta=self.use_2cta_instrs, + ) + + # sync with compute warp before tmem ptr is retrieved + self.tmem_ptr_sync_bar.arrive() + + # Retrieving tensor memory ptr and make accumulator tensor + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + + load_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_q_stage + ) + load_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_kv_stage + ) + mma_s_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_s_stage + ) + p_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_mma_stage + ) + mma_o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_o_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_kv_pipeline=load_kv_pipeline, + tmem_ptr=tmem_ptr, + is_leader_cta=is_leader_cta, + L=mCL.shape[1], + ) + mma_qk_params = SimpleNamespace( + mma_s_pipeline=mma_s_pipeline, + sQ=sQ, + sKC=sKC, + ) + mma_pv_params = SimpleNamespace( + p_mma_pipeline=p_mma_pipeline, + mma_o_pipeline=mma_o_pipeline, + sP=sP, + sVC=sVC, + ) + ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma( + mma_common_params, + mma_qk_params, + mma_pv_params, + k_tile_count, + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mma_s_pipeline.producer_tail(mma_s_producer_state) + mma_o_pipeline.producer_tail(mma_o_producer_state) + + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=self.use_2cta_instrs) + # Dealloc the tensor memory buffer + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + + cute.arch.dealloc_tmem( + tmem_ptr, + cute.arch.SM100_TMEM_CAPACITY_COLUMNS, + is_two_cta=self.use_2cta_instrs, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Compute warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.compute_warp_ids[0] + and warp_idx <= self.compute_warp_ids[-1] + ): + cute.arch.warpgroup_reg_alloc(self.softmax_reg_num) + mma_s_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_s_stage + ) + p_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_mma_stage + ) + p_cor_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + self.tmem_ptr_sync_bar.wait() + + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + tmem_ptr=tmem_ptr, + tidx=tidx, + p_cor_pipeline=p_cor_pipeline, + ) + compute_softmax_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + sP=sP, + mma_s_pipeline=mma_s_pipeline, + p_mma_pipeline=p_mma_pipeline, + softmax_scale_log2=softmax_scale_log2, + ) + mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( + self.compute( + compute_common_params, + compute_softmax_params, + k_index=k_index, + k_tile_count=k_tile_count, + mma_s_consumer_state=mma_s_consumer_state, + p_mma_producer_state=p_mma_producer_state, + p_cor_producer_state=p_cor_producer_state, + ) + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + p_cor_pipeline.producer_tail(p_cor_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.correction_warp_ids[0] + and warp_idx <= self.correction_warp_ids[-1] + ): + cute.arch.warpgroup_reg_alloc(self.correction_reg_num) + p_cor_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + self.tmem_ptr_sync_bar.wait() + + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=epilogue_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + H=mQL.shape[0], + tmem_ptr=tmem_ptr, + tidx=tidx, + tiled_mma_pv=tiled_mma_pv, + p_cor_pipeline=p_cor_pipeline, + mma_o_pipeline=mma_o_pipeline, + ) + compute_epilogue_params = SimpleNamespace( + output_scale=output_scale, + softmax_scale_log2=softmax_scale_log2, + mAccLSE=mAccLSE, + mLSE=mLSE, + ) + p_cor_consumer_state, mma_o_consumer_state = self.correction( + compute_common_params, + compute_epilogue_params, + k_tile_count=k_tile_count, + p_cor_consumer_state=p_cor_consumer_state, + mma_o_consumer_state=mma_o_consumer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # Arrive for the tensor memory deallocation barrier + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1) + + return + + @cute.kernel + def reduction_kernel( + self, + mO: cute.Tensor, + mLSE: cute.Tensor, + mAccO: cute.Tensor, + mAccLSE: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + ): + """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results + from multiple split_kv blocks into final outputs. + + :param mO: Output tensor for storing final results + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor for storing final LSE values + :type mLSE: cute.Tensor + :param mAccO: Accumulated output tensor from split_kv blocks + :type mAccO: cute.Tensor + :param mAccLSE: Accumulated LSE tensor from split_kv blocks + :type mAccLSE: cute.Tensor + :param split_kv: Number of split_kv blocks + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) + :type block_split_kvs: cute.Tensor + """ + bidx, _, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + blk_coord = (bidx, 0, bidz) + local_split_kv = ( + block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv + ) + k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) + local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) + + # Alloc shared memory + smem = utils.SmemAllocator() + storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) + lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) + smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + + gLSE = mAccLSE[blk_coord[0], None, blk_coord[2]] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + # calculate the global lse and exp ^ (local_lse - global_lse) + lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) + + local_lse = cute.make_rmem_tensor( + cute.make_layout(lse_per_thread), self.lse_dtype + ) + lse_max = -self.lse_dtype.inf + # find the max lse + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + local_lse[i] = ( + gLSE[split_kv_idx] + if cute.elem_less(split_kv_idx, local_split_kv) + else -self.lse_dtype.inf + ) + # reduce the local lse + lse_max = cute.arch.fmax(lse_max, local_lse[i]) + lse_max = cute.arch.warp_reduction_max(lse_max) + lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 + # calculate sum_lse + sum_lse = 0.0 + for i in cutlass.range_constexpr(lse_per_thread): + sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) + sum_lse = cute.arch.warp_reduction_sum(sum_lse) + # calculate the global_lse + global_lse = ( + lse_max + cute.math.log2(sum_lse, fastmath=True) + if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse + else self.lse_dtype.inf + ) + if tidx == 0: + mLSE[blk_coord[0], blk_coord[2]] = global_lse + # store the scale to shared memory + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + if cute.elem_less(split_kv_idx, local_split_kv): + smem_lse_scale[split_kv_idx] = cute.math.exp2( + local_lse[i] - global_lse, fastmath=True + ) + + pipeline.sync(barrier_id=4) + + elements_per_thread = cute.ceil_div( + self.latent_dim, self.threads_per_warp * self.num_compute_warps + ) + gAccO = mAccO[blk_coord[0], None, None, blk_coord[2]] + rAccO = cute.make_rmem_tensor( + cute.make_layout(elements_per_thread), self.acc_dtype + ) + rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) + rAccO.fill(0.0) + for i in range(local_split_kv): + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] + rO.store(rAccO.load().to(self.o_dtype)) + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + mO[blk_coord[0], element_idx, blk_coord[2]] = rO[j] + return + + @staticmethod + def get_split_kv( + B: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int + ) -> int: + """Get the proper split_kv value for the MLA kernel based on parameters. + + :param B: Batch size + :type B: int + :param K: Sequence length + :type K: int + :param mma_qk_tiler_mn: MLA tiling parameters + :type mma_qk_tiler_mn: tuple + :param max_active_blocks: Maximum number of active blocks + :type max_active_blocks: int + :return: Split_kv value + :rtype: int + """ + max_splits = ceil_div(K, mma_qk_tiler_mn[1]) + blocks_per_batch = max(1, max_active_blocks // B) + split_heur = min(max_splits, blocks_per_batch) + k_waves = ceil_div(max_splits, split_heur) + split_wave_aware = ceil_div(max_splits, k_waves) + return split_wave_aware + + @cute.jit + def get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. + + :param split_kv: Split_kv value + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param blk_coord: Block coordinate + :type blk_coord: cute.Coord + :return: k_index, k_tile_count, split_kv + :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + """ + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def load_page_table( + self, + common_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_pt_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load warp to load page table. Updates the load pt producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_pt_producer_state: The load pt producer state + :type load_pt_producer_state: pipeline.PipelineState + + :return: The load pt producer state + :rtype: pipeline.PipelineState + """ + mPT = common_params.mPT[None, common_params.blk_coord[2]] + page_per_tile = self.mma_qk_tiler[1] >> cute.arch.log2_of_pow2_int( + common_params.page_size + ) + tidx = common_params.tidx % self.threads_per_warp + + load_pt_pipeline = common_params.load_pt_pipeline + while k_tile_count > 0: + load_pt_pipeline.producer_acquire(load_pt_producer_state) + + elem_per_thread = cute.ceil_div(page_per_tile, self.threads_per_warp) + + # atom_async_copy: async copy atom for page table load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + cutlass.Int32, + num_bits_per_copy=cutlass.Int32.width, + ) + mPT_for_copy = cute.flat_divide(mPT, (1,)) + sPT_for_copy = cute.flat_divide(common_params.sPT, (1,)) + # elem_per_thread is a dynamic value depends on the page_size setting. + for i in range(elem_per_thread): + idx = i * self.threads_per_warp + tidx + if cute.elem_less( + k_index * page_per_tile + idx, mPT.shape[0] + ) and cute.elem_less(idx, page_per_tile): + cute.copy( + atom_async_copy, + mPT_for_copy[None, k_index * page_per_tile + idx], + sPT_for_copy[None, idx, load_pt_producer_state.index], + ) + else: + sPT_for_copy[None, idx, load_pt_producer_state.index].fill(0) + mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state) + load_pt_pipeline.producer_commit(load_pt_producer_state) + load_pt_producer_state.advance() + k_index += 1 + k_tile_count -= 1 + + return load_pt_producer_state + + @cute.jit + def load_cpasync( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_pt_consumer_state: pipeline.PipelineState, + load_pt_release_state: pipeline.PipelineState, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_kv_commit_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Load warp to load cpasync. Updates the load cpasync producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param load_pt_consumer_state: The load pt consumer state + :type load_pt_consumer_state: pipeline.PipelineState + :param load_pt_release_state: The load pt release state + :type load_pt_release_state: pipeline.PipelineState + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_kv_commit_state: The load kv commit state + :type load_kv_commit_state: pipeline.PipelineState + + :return: The load pt consumer state, the load pt release state, the load q producer state, the load kv producer state, the load kv commit state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + tidx = ( + common_params.tidx - self.threads_per_warp * self.load_cp_async_warp_ids[0] + ) + + # slice view the the global tensors for cpasync, their coords are from counting tensor coord. + mCL_for_slice = cute.make_tensor( + qk_params.mCL.iterator, + cute.make_layout( + ( + (qk_params.mCL.shape[0], qk_params.mCL.shape[2]), + qk_params.mCL.shape[1], + ), + stride=( + (qk_params.mCL.stride[0], qk_params.mCL.stride[2]), + qk_params.mCL.stride[1], + ), + ), + ) + mKR_for_slice = cute.make_tensor( + qk_params.mKR.iterator, + cute.make_layout( + ( + (qk_params.mKR.shape[0], qk_params.mKR.shape[2]), + qk_params.mKR.shape[1], + ), + stride=( + (qk_params.mKR.stride[0], qk_params.mKR.stride[2]), + qk_params.mKR.stride[1], + ), + ), + ) + mCLT_for_slice = cute.make_tensor( + v_params.mCLT.iterator, + cute.make_layout( + ( + v_params.mCLT.shape[0], + (v_params.mCLT.shape[1], v_params.mCLT.shape[2]), + ), + stride=( + v_params.mCLT.stride[0], + (v_params.mCLT.stride[1], v_params.mCLT.stride[2]), + ), + ), + ) + + # make identity tensor for partition + mCL_for_partition = cute.make_identity_tensor( + (qk_params.mCL.shape[0] * qk_params.mCL.shape[2], qk_params.mCL.shape[1]) + ) + mKR_for_partition = cute.make_identity_tensor( + (qk_params.mKR.shape[0] * qk_params.mKR.shape[2], qk_params.mKR.shape[1]) + ) + mCLT_for_partition = cute.make_identity_tensor( + (v_params.mCLT.shape[0], v_params.mCLT.shape[1] * v_params.mCLT.shape[2]) + ) + + # Flatten divide and partition global tensors for QK TMA load + # (bM, bK, rM, rK, rL) + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk) + + mma_qk_tiler_nk = cute.select(self.mma_qk_tiler, mode=[1, 2]) + gCL = cute.flat_divide(mCL_for_partition, mma_qk_tiler_nk) + gKR = cute.flat_divide(mKR_for_partition, mma_qk_tiler_nk) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + tSgCL = thr_mma_qk.partition_B(gCL) + tSgKR = thr_mma_qk.partition_B(gKR) + + # create cpasync tiled copy qk + cpasync_bits = 128 + # thread for copy + thread = self.threads_per_warp * self.num_load_warps + # Value for copy + value = cpasync_bits // self.q_dtype.width + cpasync_q_tiled_copy = cute.make_cotiled_copy( + cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.q_dtype, + num_bits_per_copy=cpasync_bits, + ), + cute.make_ordered_layout((thread, value), (1, 0)), + qk_params.sQ[None, None, None, 0].layout, + ) + cpasync_kc_tiled_copy = cute.make_cotiled_copy( + cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.q_dtype, + num_bits_per_copy=cpasync_bits, + ), + cute.make_ordered_layout((thread, value), (1, 0)), + qk_params.sKC[None, None, None, 0].layout, + ) + cpasync_q_thr_copy = cpasync_q_tiled_copy.get_slice(tidx) + cpasync_kc_thr_copy = cpasync_kc_tiled_copy.get_slice(tidx) + # copy async partition + tQgQL = cpasync_q_thr_copy.partition_S(tSgQL) + tQgQR = cpasync_q_thr_copy.partition_S(tSgQR) + tQsQ = cpasync_q_thr_copy.partition_D(qk_params.sQ) + + tKCgCL = cpasync_kc_thr_copy.partition_S(tSgCL) + tKCgKR = cpasync_kc_thr_copy.partition_S(tSgKR) + tKCsKC = cpasync_kc_thr_copy.partition_D(qk_params.sKC) + + gCLT = cute.flat_divide( + mCLT_for_partition, cute.select(self.mma_pv_tiler, mode=[1, 2]) + ) + thr_mma_pv = v_params.tiled_mma_pv.get_slice( + common_params.blk_coord[0] % cute.size(v_params.tiled_mma_pv.thr_id) + ) + tOgCLT = thr_mma_pv.partition_B(gCLT) + + cpasync_v_tiled_copy = cute.make_cotiled_copy( + cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.q_dtype, + num_bits_per_copy=cpasync_bits, + ), + cute.make_ordered_layout((thread, value), (1, 0)), + v_params.sVC[None, None, None, 0].layout, + ) + cpasync_v_thr_copy = cpasync_v_tiled_copy.get_slice(tidx) + tVCgCLT = cpasync_v_thr_copy.partition_S(tOgCLT) + tVCsVC = cpasync_v_thr_copy.partition_D(v_params.sVC) + + # Use to record the in-flight cpasync stage count, wait and producer commit until `load_kv_stage - 1` cpasync arrive + copy_in_flight_count = cutlass.Int32(0) + + qk_params.tiled_copy_q = cpasync_q_tiled_copy + qk_params.tiled_copy_kc = cpasync_kc_tiled_copy + qk_params.mCL_for_slice = mCL_for_slice + qk_params.mKR_for_slice = mKR_for_slice + qk_params.tQgQL = tQgQL + qk_params.tQgQR = tQgQR + qk_params.tQsQ = tQsQ + qk_params.tKCgCL = tKCgCL + qk_params.tKCgKR = tKCgKR + qk_params.tKCsKC = tKCsKC + + v_params.tiled_copy_vc = cpasync_v_tiled_copy + v_params.tVCgCLT = tVCgCLT + v_params.tVCsVC = tVCsVC + v_params.mCLT_for_slice = mCLT_for_slice + + # first load qk latent/rope + ( + load_pt_consumer_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) = self.load_cpasync_qk_one_k_tile( + common_params, + qk_params, + k_index, + load_pt_consumer_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + load_q=True, + ) + + k_index += 1 + k_tile_count -= 1 + + # mainloop, load qk and v + while k_tile_count > 0: + ( + load_pt_consumer_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) = self.load_cpasync_qk_one_k_tile( + common_params, + qk_params, + k_index, + load_pt_consumer_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + load_q=False, + ) + ( + load_pt_release_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) = self.load_cpasync_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_pt_release_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) + k_index += 1 + k_tile_count -= 1 + + # load last tile of v + ( + load_pt_release_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) = self.load_cpasync_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_pt_release_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) + + padding_in_flight = 0 + while copy_in_flight_count + padding_in_flight < self.load_kv_stage - 1: + padding_in_flight += 1 + cute.arch.cp_async_commit_group() + # wait for previous cpasync arrive + load_kv_pipeline = common_params.load_kv_pipeline + while copy_in_flight_count > 0: + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(self.load_kv_stage - 1) + load_kv_pipeline.producer_commit(load_kv_commit_state) + load_kv_commit_state.advance() + copy_in_flight_count -= 1 + + # wait all cpasync arrive + cute.arch.cp_async_wait_group(0) + return ( + load_pt_consumer_state, + load_pt_release_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + ) + + @cute.jit + def load_cpasync_one_smem_stage( + self, + common_params: SimpleNamespace, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_kv_commit_state: pipeline.PipelineState, + copy_func: Callable, + copy_in_flight_count: cutlass.Int32, + load_q: bool, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Int32, + ]: + """Load one smem stage of cpasync. Reused for qkv load stages. Updates the load cpasync producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param load_pt_consumer_state: The load pt consumer state + :type load_pt_consumer_state: pipeline.PipelineState + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_kv_commit_state: The load kv commit state + :type load_kv_commit_state: pipeline.PipelineState + :param copy_func: The copy function + :type copy_func: Callable + :param copy_in_flight_count: The copy in-flight count + :type copy_in_flight_count: cutlass.Int32 + :param load_q: Whether to load q + :type load_q: bool + + :return: The load q producer state, the load kv producer state, the load kv commit state, the copy in-flight count + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32] + """ + if cutlass.const_expr(load_q): + common_params.load_q_pipeline.producer_acquire(load_q_producer_state) + common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) + producer_index = load_kv_producer_state.index + copy_func(producer_index) + cute.arch.cp_async_commit_group() + + if cutlass.const_expr(load_q): + # directly commit the q producer state here, mma will wait for kv. + common_params.load_q_pipeline.producer_commit(load_q_producer_state) + load_q_producer_state.advance() + load_kv_producer_state.advance() + copy_in_flight_count += 1 + + # wait cpasync arrive until the last stage + load_kv_pipeline = common_params.load_kv_pipeline + if copy_in_flight_count == self.load_kv_stage: + cute.arch.cp_async_wait_group(self.load_kv_stage - 1) + load_kv_pipeline.producer_commit(load_kv_commit_state) + load_kv_commit_state.advance() + copy_in_flight_count -= 1 + + return ( + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) + + @cute.jit + def load_cpasync_page_table_lookup_copy( + self, + tiled_copy: cute.TiledCopy, + gKV: cute.Tensor, + sKV: cute.Tensor, + sPT: cute.Tensor, + gKV_for_slice: cute.Tensor, + k_index: cutlass.Int32, + latent_idx: cutlass.Int32, + qkv_stage_idx: cutlass.Int32, + page_table_stage: cutlass.Int32, + page_size: cutlass.Int32, + transpose: bool = False, + ): + """Make page table lookup for KV cache latent/rope, then do atom copy of cpasync. + + :param tiled_copy: The tiled copy + :type tiled_copy: cute.TiledCopy + :param gKV: The global KV tensor + :type gKV: cute.Tensor + :param sKV: The sliced KV tensor + :type sKV: cute.Tensor + :param sPT: The sliced page table tensor + :type sPT: cute.Tensor + :param gKV_for_slice: The global KV for slice tensor + :type gKV_for_slice: cute.Tensor + :param k_index: The k index + :type k_index: cutlass.Int32 + :param latent_idx: The latent index + :type latent_idx: cutlass.Int32 + :param qkv_stage_idx: The qkv stage index + :type qkv_stage_idx: cutlass.Int32 + :param page_table_stage: The page table stage + :type page_table_stage: cutlass.Int32 + :param transpose: Whether to transpose the gKV_for_slice + :type transpose: bool + """ + rest_modes_start = 1 + rest_modes_end = 4 + if cutlass.const_expr(transpose): + gKV_grouped = cute.group_modes( + gKV[None, None, None, None, latent_idx, k_index], + rest_modes_start, + rest_modes_end, + ) + else: + gKV_grouped = cute.group_modes( + gKV[None, None, None, None, k_index, latent_idx], + rest_modes_start, + rest_modes_end, + ) + sKV_grouped = cute.group_modes( + sKV[None, None, None, None, qkv_stage_idx], rest_modes_start, rest_modes_end + ) + page_size_log2 = cute.arch.log2_of_pow2_int(page_size) + page_per_tile = self.mma_qk_tiler[1] >> page_size_log2 + gKV_for_copy_offsets = cute.make_rmem_tensor( + cute.size(gKV_grouped.shape[1]), cute.cosize(gKV_for_slice.layout).dtype + ) + # unroll the rest of the loop to apply page table lookup. + for i in cutlass.range_constexpr(cute.size(gKV_grouped.shape[1])): + # get the coordinate of the gKV_for_slice + coord = gKV_grouped[None, i].iterator + if cutlass.const_expr(transpose): + # fast path of mod & div here to avoid the division because of the page size is power of 2. + page_coord = ((coord[1] & (page_size - 1)), coord[1] >> page_size_log2) + new_coord = (coord[0], page_coord) + new_coord_pt = new_coord[1][1] & (page_per_tile - 1) + gKV_for_copy_offset = cute.crd2idx( + ( + new_coord[0], + (new_coord[1][0], sPT[new_coord_pt, page_table_stage]), + ), + gKV_for_slice.layout, + ) + else: + # fast path of mod & div here to avoid the division because of the page size is power of 2. + page_coord = (coord[0] & (page_size - 1), coord[0] >> page_size_log2) + new_coord = (page_coord, coord[1]) + new_coord_pt = new_coord[0][1] & (page_per_tile - 1) + gKV_for_copy_offset = cute.crd2idx( + ( + (new_coord[0][0], sPT[new_coord_pt, page_table_stage]), + new_coord[1], + ), + gKV_for_slice.layout, + ) + gKV_for_copy_offsets[i] = gKV_for_copy_offset + cpasync_bits = 128 + for i in cutlass.range_constexpr(cute.size(gKV_grouped.shape[1])): + # calculate the actual offset and apply. + sKV_for_copy = sKV_grouped[None, i] + gKV_for_copy_offset = cute.assume( + gKV_for_copy_offsets[i], cpasync_bits // self.q_dtype.width + ) + gKV_for_copy_iter = gKV_for_slice.iterator + gKV_for_copy_offset + gKV_for_copy = cute.make_tensor(gKV_for_copy_iter, sKV_for_copy.layout) + cute.copy(tiled_copy, gKV_for_copy, sKV_for_copy) + return + + @cute.jit + def load_cpasync_qk_one_k_tile( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + load_pt_consumer_state: pipeline.PipelineState, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_kv_commit_state: pipeline.PipelineState, + copy_in_flight_count: cutlass.Int32, + load_q: bool, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Int32, + ]: + """Load one k tile of Q/K. Updates the load cpasync producer state. + + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param load_pt_consumer_state: The load pt consumer state + :type load_pt_consumer_state: pipeline.PipelineState + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_kv_commit_state: The load kv commit state + :type load_kv_commit_state: pipeline.PipelineState + :param copy_in_flight_count: The copy stage count + :type copy_in_flight_count: int + :param load_q: Whether to load q + :type load_q: bool + + :return: The load pt consumer state, the load q producer state, the load kv producer state, the load kv commit state, the copy stage count + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, int] + """ + common_params.load_pt_pipeline.consumer_wait(load_pt_consumer_state) + page_table_stage = load_pt_consumer_state.index + load_pt_consumer_state.advance() + + def copy_qk_latent(latent_idx, qkv_stage_idx): + if load_q: + cute.copy( + qk_params.tiled_copy_q, + qk_params.tQgQL[ + None, + None, + None, + None, + 0, + latent_idx, + common_params.blk_coord[2], + ], + qk_params.tQsQ[None, None, None, None, latent_idx], + ) + # make sure the page table lookups first. + self.load_cpasync_page_table_lookup_copy( + qk_params.tiled_copy_kc, + qk_params.tKCgCL, + qk_params.tKCsKC, + common_params.sPT, + qk_params.mCL_for_slice, + k_index, + latent_idx, + qkv_stage_idx, + page_table_stage, + common_params.page_size, + ) + + def copy_qk_rope(latent_idx, qkv_stage_idx): + if load_q: + cute.copy( + qk_params.tiled_copy_q, + qk_params.tQgQR[ + None, + None, + None, + None, + 0, + latent_idx, + common_params.blk_coord[2], + ], + qk_params.tQsQ[ + None, None, None, None, self.iterations_qk_latent + latent_idx + ], + ) + # make sure the page table lookups first. + self.load_cpasync_page_table_lookup_copy( + qk_params.tiled_copy_kc, + qk_params.tKCgKR, + qk_params.tKCsKC, + common_params.sPT, + qk_params.mKR_for_slice, + k_index, + latent_idx, + qkv_stage_idx, + page_table_stage, + common_params.page_size, + ) + + # use dynamic loop here to avoid instruction cache miss. + for i in range(self.iterations_qk_latent): + ( + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) = self.load_cpasync_one_smem_stage( + common_params, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + partial(copy_qk_latent, i), + copy_in_flight_count, + load_q=load_q, + ) + for i in range(self.iterations_qk_rope): + ( + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) = self.load_cpasync_one_smem_stage( + common_params, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + partial(copy_qk_rope, i), + copy_in_flight_count, + load_q=load_q, + ) + + return ( + load_pt_consumer_state, + load_q_producer_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) + + @cute.jit + def load_cpasync_v_one_k_tile( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + load_pt_release_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_kv_commit_state: pipeline.PipelineState, + copy_in_flight_count: cutlass.Int32, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Int32, + ]: + """Load one k tile of V. Updates the load cpasync producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param load_pt_release_state: The load pt release state + :type load_pt_release_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_kv_commit_state: The load kv commit state + :type load_kv_commit_state: pipeline.PipelineState + :param copy_in_flight_count: The copy in-flight count + :type copy_in_flight_count: cutlass.Int32 + + :return: The load pt release state, the load kv producer state, the load kv commit state, the copy in-flight count + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32] + """ + page_table_stage = load_pt_release_state.index + + def copy_v_latent(iter_k_idx, latent_idx, qkv_stage_idx): + # make sure the page table lookups first. + self.load_cpasync_page_table_lookup_copy( + v_params.tiled_copy_vc, + v_params.tVCgCLT, + v_params.tVCsVC, + common_params.sPT, + v_params.mCLT_for_slice, + k_index * self.iterations_pv_k + iter_k_idx, + latent_idx, + qkv_stage_idx, + page_table_stage, + common_params.page_size, + transpose=True, + ) + + # use dynamic loop here to avoid instruction cache miss. + for i in range(self.iterations_pv_k): + for j in range(self.iterations_pv_n): + ( + _, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) = self.load_cpasync_one_smem_stage( + common_params, + None, + load_kv_producer_state, + load_kv_commit_state, + partial(copy_v_latent, i, j), + copy_in_flight_count, + load_q=False, + ) + common_params.load_pt_pipeline.consumer_release(load_pt_release_state) + load_pt_release_state.advance() + return ( + load_pt_release_state, + load_kv_producer_state, + load_kv_commit_state, + copy_in_flight_count, + ) + + @cute.jit + def load_tma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load wrap to load Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param v_params: The v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + + :return: The load q producer state and load kv producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + # page table + mPT = None + if cutlass.const_expr(self.use_page_table): + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for QK TMA load + # (bM, bK, rM, rK, rL) + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk) + + mma_qk_tiler_nk = cute.select(self.mma_qk_tiler, mode=[1, 2]) + gCL = cute.flat_divide(qk_params.mCL, mma_qk_tiler_nk) + gKR = cute.flat_divide(qk_params.mKR, mma_qk_tiler_nk) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + tSgCL = thr_mma_qk.partition_B(gCL) + tSgKR = thr_mma_qk.partition_B(gKR) + + # tma partition for q, k latent/rope + + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tQsQ, tQLgQL_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQL, 0, 3), + ) + + _, tQRgQR_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQR, 0, 3), + ) + + tKCsKC, tCLgCL = cpasync.tma_partition( + qk_params.tma_atom_c_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sKC, 0, 3), + cute.group_modes(tSgCL, 0, 3), + ) + + _, tKRgKR = cpasync.tma_partition( + qk_params.tma_atom_c_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sKC, 0, 3), + cute.group_modes(tSgKR, 0, 3), + ) + + tQLgQL = tQLgQL_mkl[None, None, None, common_params.blk_coord[2]] + tQRgQR = tQRgQR_mkl[None, None, None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for V TMA load + mma_pv_tiler_nk = cute.select(self.mma_pv_tiler, mode=[1, 2]) + gCLT = cute.flat_divide(v_params.mCLT, mma_pv_tiler_nk) + + thr_mma_pv = v_params.tiled_mma_pv.get_slice( + common_params.blk_coord[0] % cute.size(v_params.tiled_mma_pv.thr_id) + ) + tOgCLT = thr_mma_pv.partition_B(gCLT) + + # tma partition for vc + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tVCsVC, tCLTgCLT = cpasync.tma_partition( + v_params.tma_atom_c_latent_transpose, + 0, + cute.make_layout(1), + cute.group_modes(v_params.sVC, 0, 3), + cute.group_modes(tOgCLT, 0, 3), + ) + + # set extra params + common_params.mPT = mPT + qk_params.tQLgQL = tQLgQL + qk_params.tQRgQR = tQRgQR + qk_params.tCLgCL = tCLgCL + qk_params.tKRgKR = tKRgKR + qk_params.tQsQ = tQsQ + qk_params.tKCsKC = tKCsKC + v_params.tCLTgCLT = tCLTgCLT + v_params.tVCsVC = tVCsVC + + load_q_producer_state, load_kv_producer_state = self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_q=True, + ) + k_index += 1 + k_tile_count -= 1 + while k_tile_count > 0: + load_q_producer_state, load_kv_producer_state = self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_q=False, + ) + load_kv_producer_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_kv_producer_state, + ) + k_index += 1 + k_tile_count -= 1 + + # load last v tile + load_kv_producer_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_kv_producer_state, + ) + return load_q_producer_state, load_kv_producer_state + + @cute.jit + def load_tma_qk_one_k_tile( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_q: bool, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_q: Whether to load q + :type load_q: bool + + :return: The load q producer state and load kv producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + k_idx = cute.make_rmem_tensor(cute.make_layout(2), cutlass.Int32) + # prefetch next K load to keep busy while we transpose-load from cache + kPrefetchDistance = 1 + if cutlass.const_expr(self.use_page_table): + k_idx[0] = common_params.mPT[k_index] + k_idx[1] = common_params.mPT[k_index + kPrefetchDistance] + else: + k_idx[0] = common_params.blk_coord[2] + k_idx[1] = common_params.blk_coord[2] + for i in cutlass.range_constexpr(self.iterations_qk_latent): + # load q once at first iteration + if cutlass.const_expr(load_q): + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier( + load_q_producer_state + ) + # expect the extra bytes for q. + common_params.load_q_pipeline.producer_acquire(load_q_producer_state) + # load q latent + cute.copy( + qk_params.tma_atom_q_latent, + qk_params.tQLgQL[None, 0, load_q_producer_state.index], + qk_params.tQsQ[None, load_q_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_q_producer_state.advance() + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier( + load_kv_producer_state + ) + # expect the extra bytes for q. + common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) + # load k latent + if cutlass.const_expr(self.use_page_table): + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, 0, i, k_idx[0]], + qk_params.tKCsKC[None, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + else: + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, k_index, i, k_idx[0]], + qk_params.tKCsKC[None, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_kv_producer_state.advance() + + for i in cutlass.range_constexpr(self.iterations_qk_rope): + # load q rope once at first iteration + if cutlass.const_expr(load_q): + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier( + load_q_producer_state + ) + # expect the extra bytes for q. + common_params.load_q_pipeline.producer_acquire(load_q_producer_state) + # load q rope + cute.copy( + qk_params.tma_atom_q_rope, + qk_params.tQRgQR[None, 0, i], + qk_params.tQsQ[None, i + self.iterations_qk_latent], + tma_bar_ptr=tma_bar_ptr, + ) + load_q_producer_state.advance() + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier( + load_kv_producer_state + ) + # expect the extra bytes for q. + common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) + # load k rope + if cutlass.const_expr(self.use_page_table): + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, 0, i, k_idx[0]], + qk_params.tKCsKC[None, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + else: + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, k_index, i, k_idx[0]], + qk_params.tKCsKC[None, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_kv_producer_state.advance() + + for i in cutlass.range_constexpr(self.iterations_qk_latent): + if cutlass.const_expr(self.use_page_table): + if k_tile_count > kPrefetchDistance: + cute.prefetch( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[ + None, + k_index, + i, + k_idx[1], + ], + ) + else: + cute.prefetch( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, k_index + kPrefetchDistance, i, k_idx[1]], + ) + + for i in cutlass.range_constexpr(self.iterations_qk_rope): + if cutlass.const_expr(self.use_page_table): + if k_tile_count > kPrefetchDistance: + cute.prefetch( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[ + None, + k_index, + i, + k_idx[1], + ], + ) + else: + cute.prefetch( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, k_index + kPrefetchDistance, i, k_idx[1]], + ) + return load_q_producer_state, load_kv_producer_state + + @cute.jit + def load_tma_v_one_k_tile( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + load_kv_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The load tma v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param load_kv_producer_state: The load qkv producer state + :type load_kv_producer_state: pipeline.PipelineState + + :return: The load qkv producer state + :rtype: pipeline.PipelineState + """ + k_idx = cute.make_rmem_tensor(cute.make_layout(1), cutlass.Int32) + if cutlass.const_expr(self.use_page_table): + k_idx[0] = common_params.mPT[k_index] + else: + k_idx[0] = common_params.blk_coord[2] + for i in cutlass.range_constexpr(self.iterations_pv_k): + for j in cutlass.range_constexpr(self.iterations_pv_n): + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier( + load_kv_producer_state + ) + common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) + if cutlass.const_expr(self.use_page_table): + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[None, j, i, k_idx[0]], + v_params.tVCsVC[None, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + else: + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[ + None, + j, + k_index * self.iterations_pv_k + i, + k_idx[0], + ], + v_params.tVCsVC[None, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_kv_producer_state.advance() + return load_kv_producer_state + + @cute.jit + def mma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + pv_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_kv_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. + + :param common_params: The common parameters for mma qk and pv + :type common_params: SimpleNamespace + :param qk_params: The mma qk parameters + :type qk_params: SimpleNamespace + :param pv_params: The mma pv parameters + :type pv_params: SimpleNamespace + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + :param p_mma_consumer_state: The p mma consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The mma o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load kv consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state + :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) + tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) + tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) + tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) + + tStS_shape = tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + # use real tmem ptr for tStS + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + # mma O has 1 stage. + assert self.mma_o_stage == 1, ( + "mma O has 1 stage, otherwise the tmem usage exceeds the limit." + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + tStS_staged.iterator + self.tmem_o_offset, tOtO_layout + ) + + # set more parameters + qk_params.tSrQ = tSrQ + qk_params.tSrKC = tSrKC + qk_params.tStS_staged = tStS_staged + pv_params.tOrP = tOrP + pv_params.tOrVC = tOrVC + pv_params.tOtO_staged = tOtO_staged + + # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + if common_params.is_leader_cta: + load_q_release_state = load_q_consumer_state.clone() + ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + wait_q=True, + ) + k_tile_count -= 1 + + while k_tile_count > 0: + ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + wait_q=False, + ) + ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + k_tile_count -= 1 + # release q consumer states + for i in cutlass.range_constexpr(self.iterations_qk): + load_q_pipeline.consumer_release(load_q_release_state) + load_q_release_state.advance() + ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + return ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def mma_qk( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_kv_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + wait_q: bool, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. + + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the load q consumer state, the load kv consumer state, and the mma s producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] + + qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + load_kv_pipeline = common_params.load_kv_pipeline + for q_stage in range(self.iterations_qk_latent): + if cutlass.const_expr(wait_q): + load_q_pipeline.consumer_wait(load_q_consumer_state) + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + kc_stage = load_kv_consumer_state.index + for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[None, None, k_block, q_stage], + qk_params.tSrKC[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + if cutlass.const_expr(wait_q): + load_q_consumer_state.advance() + for q_stage in range(self.iterations_qk_rope): + if cutlass.const_expr(wait_q): + load_q_pipeline.consumer_wait(load_q_consumer_state) + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + kc_stage = load_kv_consumer_state.index + for k_block in cutlass.range_constexpr( + self.rope_dim // tiled_mma_qk.shape_mnk[2] + ): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[ + None, None, k_block, q_stage + self.iterations_qk_latent + ], + qk_params.tSrKC[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + if cutlass.const_expr(wait_q): + load_q_consumer_state.advance() + + qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) + mma_s_producer_state.advance() + return ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) + + @cute.jit + def mma_pv( + self, + common_params: SimpleNamespace, + pv_params: SimpleNamespace, + tiled_mma_pv: cute.TiledMma, + load_kv_consumer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param pv_params: The pv parameters + :type pv_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param p_mma_consumer_state: The P MMA consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The MMA o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma pv, the load qkv consumer state, the P MMA consumer state, and the MMA o producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + pv_params.mma_o_pipeline.producer_acquire(mma_o_producer_state) + pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) + load_kv_pipeline = common_params.load_kv_pipeline + for p_stage in range(self.iterations_pv_k): + accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) + for acc_stage in range(self.iterations_pv_n): + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) + vc_stage = load_kv_consumer_state.index + tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] + for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]): + cute.gemm( + tiled_mma_pv, + tOtO, + pv_params.tOrP[ + None, + None, + k_block, + (p_stage, p_mma_consumer_state.index), + ], + pv_params.tOrVC[None, None, k_block, vc_stage], + tOtO, + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) + p_mma_consumer_state.advance() + pv_params.mma_o_pipeline.producer_commit(mma_o_producer_state) + mma_o_producer_state.advance() + + return ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def compute( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + + :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) + + row_max = -self.acc_dtype.inf + row_sum = self.acc_dtype(0) + correction_factor = self.acc_dtype(1) + while k_tile_count > 0: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax_dispatch_apply_mask( + common_params, + softmax_params, + k_index, + k_tile_total, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state + + @cute.jit + def correction( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + p_cor_consumer_state: pipeline.PipelineState, + mma_o_consumer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + :param mma_o_consumer_state: The MMA o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, and the MMA o consumer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + + p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( + self.get_correction_factor(common_params, p_cor_consumer_state) + ) + k_tile_count = k_tile_count - 1 + while k_tile_count > 0: + p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( + self.get_correction_factor(common_params, p_cor_consumer_state) + ) + mma_o_consumer_state = self.rescale( + common_params, mma_o_consumer_state, correction_factor, no_correction + ) + k_tile_count = k_tile_count - 1 + + mma_o_consumer_state = self.epilogue( + common_params, epilogue_params, mma_o_consumer_state, row_sum, row_max + ) + return p_cor_consumer_state, mma_o_consumer_state + + @cute.jit + def softmax_dispatch_apply_mask( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_total: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + row_max: cutlass.Float32, + row_sum: cutlass.Float32, + correction_factor: cutlass.Float32, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + ]: + """Dispatch whether to apply mask for softmax for last k tile. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_total: The total number of k-tiles + :type k_tile_total: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + :param row_max: The row max + :type row_max: cutlass.Float32 + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + + :return: The MMA s consumer state, the P MMA producer state, the row max, the row sum, and the correction factor + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] + """ + if k_index == k_tile_total - 1: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + True, + ) + else: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + False, + ) + return ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) + + @cute.jit + def softmax( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + row_max: cutlass.Float32, + row_sum: cutlass.Float32, + correction_factor: cutlass.Float32, + is_last_tile: bool, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + ]: + """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + :param row_max: The row max + :type row_max: cutlass.Float32 + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param is_last_tile: Whether the last tile + :type is_last_tile: bool + + :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] + """ + + softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) + softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) + + # load S from tmem + tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] + + tAcc = tStS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], + self.mma_qk_tiler[1], + self.mma_qk_tiler[2], + ) + cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + tTR_tAcc = tmem_thr_copy.partition_S(tAcc) + tTR_tS = tmem_thr_copy.partition_D(cS) + + tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) + + cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) + + row_max_new = row_max + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if cutlass.const_expr(is_last_tile): + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # update row_max + row_max_new = cute.arch.fmax(row_max_new, tTR_rAcc[i]) + + # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_max_new + self.softmax_exchange_sync_bar.wait() + row_max_new = cute.arch.fmax( + row_max_new, + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ], + ) + + # find correction factor + correction_factor = cute.math.exp2( + (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True + ) + no_correction = cutlass.Int32(row_max == row_max_new) + # softmax + fma_b = (softmax_params.softmax_scale_log2, softmax_params.softmax_scale_log2) + fma_c = ( + (0.0 - row_max_new) * softmax_params.softmax_scale_log2, + (0.0 - row_max_new) * softmax_params.softmax_scale_log2, + ) + + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.fma_packed_f32x2( + (tTR_rAcc[i], tTR_rAcc[i + 1]), fma_b, fma_c + ) + tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) + tTR_rAcc[i + 1] = cute.math.exp2(tTR_rAcc[i + 1], fastmath=True) + + tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) + + # quantize + tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) + + # create sP + sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] + sP_mk_view = cute.make_tensor( + sP.iterator, + cute.make_layout( + ( + (sP.shape[0][0], sP.shape[1]), + (sP.shape[0][1], sP.shape[2], sP.shape[3]), + ), + stride=( + (sP.stride[0][0], sP.stride[1]), + (sP.stride[0][1], sP.stride[2], sP.stride[3]), + ), + ), + ) + # change to PISL + sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) + swizzle_bits = ( + int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 + ) + swizzle_base = 3 if self.q_dtype.width == 16 else 4 + sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) + sP_mk_view = cute.make_tensor( + sP_wo_swizzle_iter, + cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), + ) + universal_copy_bits = 128 + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=universal_copy_bits, + ) + smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) + smem_thr_copy = smem_tiled_copy.get_slice(tidx) + rP_copy_view = smem_thr_copy.retile(tTR_rS) + sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) + cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) + + # row_sum, using `add_packed_f32x2` to reduce the number of instructions + row_sum = row_sum * correction_factor + row_sum_vec = (0.0, 0.0) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + row_sum_vec = cute.arch.add_packed_f32x2( + row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) + ) + row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum + + # fence between tmem load and mma s + cute.arch.fence_view_async_tmem_load() + # fence between smem store and mma o + cute.arch.fence_view_async_shared() + + softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) + softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) + mma_s_consumer_state.advance() + p_mma_producer_state.advance() + + # store correction factor/row_sum/row_max to tmem for correction warp + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + # pad for 4x32b + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + corr_layout, + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) + corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) + cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) + tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + rCor[0] = row_sum + rCor[1] = row_max_new + rCor[2] = correction_factor + rCor_int[3] = no_correction + + cute.copy( + corr_tmem_store_tiled_copy, + rCor, + tCor_for_copy[None, None, None, p_cor_producer_state.index], + ) + # fence between tmem store and correction warp + cute.arch.fence_view_async_tmem_store() + common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) + p_cor_producer_state.advance() + + return ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max_new, + row_sum, + correction_factor, + ) + + @cute.jit + def _tmem_load_partition( + self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int + ) -> tuple[ + cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma + ]: + """Tensor memory load partition for rescale and epilogue. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param iter_n: The iteration number + :type iter_n: int + + :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv + :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] + """ + + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO = cute.make_tensor( + common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout + ) + tOtO = tOtO[None, None, None, iter_n] + + tAcc = tOtO[(None, None), 0, 0] + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( + common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + ) + + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + # Flatten divide and partition global tensors for O + cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) + + gO = None + if cutlass.const_expr(common_params.mAccO is not None): + gO = cute.local_tile( + common_params.mAccO[None, common_params.blk_coord[3], None, None], + cta_pv_tiler_mn, + (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), + ) + cO = cute.local_tile( + cute.make_identity_tensor( + common_params.mAccO[ + None, common_params.blk_coord[3], None, None + ].shape + ), + cta_pv_tiler_mn, + (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), + ) + else: + gO = cute.local_tile( + common_params.mO, + cta_pv_tiler_mn, + (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), + ) + cO = cute.local_tile( + cute.make_identity_tensor(common_params.mO.shape), + cta_pv_tiler_mn, + (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), + ) + tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) + tTR_gO = tmem_load_thr_copy.partition_D(gO) + tTR_cO = tmem_load_thr_copy.partition_D(cO) + tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) + return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc + + def get_correction_factor( + self, + common_params: SimpleNamespace, + p_cor_consumer_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + cutlass.Int32, + ]: + """Get the correction factor from the P correction consumer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, the row_sum, the row_max, and the correction factor + :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] + """ + common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + # load correction factor + _, tAcc, _, _, _, _ = self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, 0 + ) + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, corr_layout + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) + corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) + tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) + cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + cute.copy( + corr_tmem_load_tiled_copy, + tCor_for_copy[None, None, None, p_cor_consumer_state.index], + rCor, + ) + row_sum = rCor[0] + row_max = rCor[1] + correction_factor = rCor[2] + no_correction = rCor_int[3] + + common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) + p_cor_consumer_state.advance() + return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction + + @cute.jit + def rescale( + self, + common_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + correction_factor: cutlass.Float32, + no_correction: cutlass.Int32, + ) -> pipeline.PipelineState: + """Rescale for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param no_correction: Whether to apply correction factor + :type no_correction: cutlass.Int32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + skip_correction = cute.arch.vote_all_sync(no_correction == 1) + if not skip_correction: + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # tmem store tiled copy + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + # rescale, using `mul_packed_f32x2` to reduce the number of instructions + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2( + ( + tTR_rAcc[i], + tTR_rAcc[i + 1], + ), + (correction_factor, correction_factor), + ) + + # store o to tensor memory for next k tile + cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) + + cute.arch.fence_view_async_tmem_store() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + @cute.jit + def epilogue( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + ) -> pipeline.PipelineState: + """Epilogue for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param row_max: The row max + :type row_max: cutlass.Float32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + # mma_o pipeline consumer wait + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + # exchange row_sum between warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_sum + self.epilogue_exchange_sync_bar.wait() + # (64, 2) + row_sum = ( + row_sum + + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ] + ) + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + + # apply output scale and normalize by row_sum + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rAcc[i], tTR_rAcc[i + 1]), + ( + epilogue_params.output_scale * cute.arch.rcp_approx(row_sum), + epilogue_params.output_scale * cute.arch.rcp_approx(row_sum), + ), + ) + + # store o to global memory + tR2G_rO_src = None + tR2G_rO_dst = tTR_gO + if cutlass.const_expr(common_params.mAccO is None): + tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) + # using final output dtype for o + tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) + else: + # using accumulate dtype for o + tR2G_rO_src = tTR_rAcc + + if cute.elem_less(tTR_cO[0][0], common_params.H): + cute.autovec_copy(tR2G_rO_src, tR2G_rO_dst) + + # store the lse to global memory + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + gLSE = None + cLSE = None + if cutlass.const_expr(epilogue_params.mAccLSE is None): + gLSE = cute.local_tile( + epilogue_params.mLSE, + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, None, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor(epilogue_params.mLSE.shape), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, None, 1), + ) + + else: + gLSE = cute.local_tile( + epilogue_params.mAccLSE[None, common_params.blk_coord[3], None], + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, None, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None + ].shape + ), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, None, 1), + ) + lse = ( + cute.math.log2(row_sum, fastmath=True) + + epilogue_params.softmax_scale_log2 * row_max + ) + if cutlass.const_expr(self.warps_in_n == 2): + if cute.elem_less(cLSE[tidx][0], common_params.H): + gLSE[tidx] = lse + + cute.arch.fence_view_async_tmem_load() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + def make_and_init_load_qkv_pipeline( + self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count, is_cpasync + ) -> pipeline.PipelineTmaUmma: + """Create and initialize the tma load qkv pipeline. + + :param load_qkv_mbar_ptr: The load qkv mbar pointer + :type load_qkv_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + :param load_stages: The load stages + :type load_stages: list[int] + :param tx_count: The tx count + :type tx_count: int + :param is_cpasync: Whether to use cpasync + :type is_cpasync: bool + + :return: The tma load qkv pipeline + :rtype: pipeline.PipelineTmaUmma + """ + if is_cpasync: + load_qkv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len(self.load_cp_async_warp_ids) + * self.threads_per_warp + * self.cluster_shape_mnk[0], + ) + load_qkv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=load_qkv_mbar_ptr, + num_stages=load_stages, + producer_group=load_qkv_producer_group, + consumer_group=load_qkv_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + ) + else: + load_qkv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_tma_warp_id]) + ) + load_qkv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_qkv_mbar_ptr, + num_stages=load_stages, + producer_group=load_qkv_producer_group, + consumer_group=load_qkv_consumer_group, + tx_count=tx_count, + cta_layout_vmnk=cta_layout_vmnk, + ) + + def make_and_init_mma_s_pipeline( + self, mma_s_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma s pipeline. + + :param mma_s_mbar_ptr: The mma s mbar pointer + :type mma_s_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma s pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_s_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_s_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_s_mbar_ptr, + num_stages=self.mma_s_stage, + producer_group=mma_s_producer_group, + consumer_group=mma_s_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + ) + + def make_and_init_p_mma_pipeline( + self, p_mma_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p mma pipeline. + + :param p_mma_mbar_ptr: The p mma mbar pointer + :type p_mma_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The p mma pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + p_mma_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_mma_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=p_mma_mbar_ptr, + num_stages=self.p_mma_stage, + producer_group=p_mma_producer_group, + consumer_group=p_mma_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + ) + + def make_and_init_p_cor_pipeline( + self, p_cor_mbar_ptr + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p correction pipeline. + + :param p_cor_mbar_ptr: The p correction mbar pointer + :type p_cor_mbar_ptr: cute.Tensor + + :return: The p correction pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) + p_cor_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_cor_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + return pipeline.PipelineAsync.create( + barrier_storage=p_cor_mbar_ptr, + num_stages=self.p_cor_stage, + producer_group=p_cor_producer_group, + consumer_group=p_cor_consumer_group, + ) + + def make_and_init_mma_o_pipeline( + self, mma_o_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma o pipeline. + + :param mma_o_mbar_ptr: The mma o mbar pointer + :type mma_o_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma o pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_o_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_o_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_o_mbar_ptr, + num_stages=self.mma_o_stage, + producer_group=mma_o_producer_group, + consumer_group=mma_o_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + ) + + def make_and_init_load_pt_pipeline(self, load_pt_mbar_ptr): + """Create and initialize the load page table pipeline. + + :param load_pt_mbar_ptr: The load page table mbar pointer + :type load_pt_mbar_ptr: cute.Tensor + + :return: The load page table pipeline + :rtype: pipeline.PipelineAsync + """ + load_pt_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len([self.load_pt_warp_id]), + ) + load_pt_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.load_cp_async_warp_ids), + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_pt_mbar_ptr, + num_stages=self.load_pt_stage, + producer_group=load_pt_producer_group, + consumer_group=load_pt_consumer_group, + ) + + @staticmethod + def _compute_grid( + o: cute.Tensor, + split_kv: cutlass.Int32, + cluster_shape_mnk: Tuple[int, int, int], + max_active_clusters: int, + is_persistent: bool, + ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Tile scheduler parameters and grid shape. + :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] + """ + o_shape = o.shape + tile_sched_params = create_mla_static_tile_scheduler_params( + is_persistent, + cute.size(o_shape[2]), + cluster_shape_mnk, + split_kv, + ) + grid = MLAStaticTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def get_workspace_size( + H: int, + D: int, + B: int, + split_kv: int, + acc_dtype: Type[cutlass.Numeric], + ) -> int: + """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. + + :param H: The height of the output tensor C + :type H: int + :param D: The depth of the output tensor C + :type D: int + :param B: The batch size of the output tensor C + :type B: int + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + + :return: The workspace size for the MLA kernel + :rtype: int + """ + if split_kv == 1: + return 0 + return B * H * split_kv * (D + 1) * acc_dtype.width // 8 + + @cute.jit + def initialize_workspace( + self, + H: cutlass.Int32, + D: cutlass.Int32, + B: cutlass.Int32, + split_kv: cutlass.Int32, + acc_dtype: Type[cutlass.Numeric], + workspace: cute.Tensor, + ) -> tuple[cute.Tensor, cute.Tensor]: + """Initialize the workspace for the MLA kernel. Construct the intermediate tensors + acc_o and acc_lse. + + :param H: The height of the output tensor C + :type H: cutlass.Int32 + :param D: The depth of the output tensor C + :type D: cutlass.Int32 + :param B: The batch size of the output tensor C + :type B: cutlass.Int32 + :param split_kv: The split key-value of the output tensor C + :type split_kv: cutlass.Int32 + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + :param workspace: The workspace tensor + :type workspace: cute.Tensor + + :return: The output tensor C and the workspace tensor + :rtype: tuple[cute.Tensor, cute.Tensor] + """ + acc_o, acc_lse = None, None + if cutlass.const_expr(workspace is not None): + align = 128 // self.q_dtype.width + acc_o_layout = cute.make_layout( + (H, split_kv, D, B), + stride=( + cute.assume(split_kv * D, align), + cute.assume(D, align), + 1, + cute.assume(H * split_kv * D, align), + ), + ) + acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) + acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) + acc_lse_layout = cute.make_layout( + (H, split_kv, B), stride=(split_kv, 1, H * split_kv) + ) + acc_lse_iter = cute.recast_ptr( + workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, + dtype=acc_dtype, + ) + acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) + return acc_o, acc_lse + + @staticmethod + def can_implement( + B: int, + K: int, + H: int, + L: int, + R: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_cpasync: bool, + is_var_seq: bool, + is_var_split_kv: bool, + use_page_table: bool, + page_size: int, + ) -> bool: + """Check if the MLA kernel can be implemented. + + :param H: The height of the output tensor C + :type H: int + :param K: The width of the output tensor C + :type K: int + :param L: The length of the output tensor C + :type L: int + :param R: The row of the output tensor C + :type R: int + :param B: The batch size of the output tensor C + :type B: int + :param in_dtype: The data type of the input tensor + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: The data type of the output tensor + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: The data type of the log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_cpasync: Whether to use cpasync + :type is_cpasync: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param use_page_table: Whether to use page table + :type use_page_table: bool + :param page_size: The page size of the page table + :type page_size: int + + :return: Whether the MLA kernel can be implemented + :rtype: bool + """ + if L != 512 or R != 64: + return False + if in_dtype not in [cutlass.Float8E4M3FN, cutlass.Float16]: + return False + if out_dtype not in [cutlass.Float8E4M3FN, cutlass.Float16]: + return False + if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: + return False + if is_cpasync: + if not use_page_table: + return False + if page_size & (page_size - 1) != 0: + return False + if page_size > mma_qk_tiler_mn[1]: + return False + else: + if use_page_table and page_size != mma_qk_tiler_mn[1]: + return False + if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: + return False + if is_var_split_kv and (not use_page_table or not is_var_seq): + return False + if is_var_seq and not use_page_table: + return False + if not is_cpasync and (H > 128 or (H < 128 and split_kv != 1)): + return False + if is_cpasync and H != 128: + return False + if K <= 0: + return False + return True + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def run( + batch_size: int, + seq_len: int, + num_heads: int, + latent_dim: int, + rope_dim: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_cpasync: bool, + is_var_seq: bool, + is_var_split_kv: bool, + use_page_table: bool, + page_size: int, + softmax_scale: float, + output_scale: float, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool, + **kwargs, +): + """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. + + This function creates random input tensors for query latent/rope, compressed latent/rope, and value, + then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, + page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + :param batch_size: Batch size + :type batch_size: int + :param seq_len: Sequence length + :type seq_len: int + :param num_heads: Number of heads + :type num_heads: int + :param latent_dim: dimension of query/compressed latent + :type latent_dim: int + :param rope_dim: dimension of query/compressed rope + :type rope_dim: int + :param in_dtype: Input data type for query/compressed latent/rope tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: Accumulator data type for query-key matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Accumulator data type for log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: Split key-value + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_cpasync: Whether to use cpasync + :type is_cpasync: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param use_page_table: Whether to use page table + :type use_page_table: bool + :param page_size: Page size of the page table + :type page_size: int + :param softmax_scale: Attention score scaling factor + :type softmax_scale: float + :param output_scale: Output scaling factor + :type output_scale: float + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + """ + + print("Running Blackwell MLA test with:") + print(f" batch_size: {batch_size}") + print(f" seq_len: {seq_len}") + print(f" num_heads: {num_heads}") + print(f" latent_dim: {latent_dim}") + print(f" rope_dim: {rope_dim}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" acc_dtype: {acc_dtype}") + print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") + print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") + print(f" split_kv: {split_kv}") + print(f" is_persistent: {is_persistent}") + print(f" is_cpasync: {is_cpasync}") + print(f" is_var_seq: {is_var_seq}") + print(f" is_var_split_kv: {is_var_split_kv}") + print(f" use_page_table: {use_page_table}") + print(f" page_size: {page_size}") + print(f" softmax_scale: {softmax_scale}") + print(f" output_scale: {output_scale}") + print(f" tolerance: {tolerance}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + if not BlackwellMultiHeadLatentAttentionForward.can_implement( + batch_size, + seq_len, + num_heads, + latent_dim, + rope_dim, + in_dtype, + out_dtype, + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + split_kv, + is_persistent, + is_cpasync, + is_var_seq, + is_var_split_kv, + use_page_table, + page_size, + ): + raise TypeError( + f"Unsupported testcase {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_cpasync}, {is_var_seq}, {is_var_split_kv}, {use_page_table}, {page_size}" + ) + + torch.manual_seed(1111) + + def create_data_tensor( + B, + HK, + D, + dtype, + is_dynamic_layout=True, + page_table=None, + cache_seqs=None, + is_lse=False, + ): + shape = (B, HK, D) + if page_table is not None: + if cache_seqs is not None: + max_seq_len = torch.max(cache_seqs) + shape = (B * ceil_div(max_seq_len, page_size), page_size, D) + else: + shape = (B * ceil_div(HK, page_size), page_size, D) + + permute_order = (1, 2, 0) + stride_order = (2, 0, 1) + leading_dim = 1 + if is_lse: + shape = (B, HK) + permute_order = (1, 0) + stride_order = (1, 0) + leading_dim = 0 + + init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) + + torch_dtype = ( + cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=init_config, + ) + + # Create dtype torch tensor (gpu) + torch_tensor_gpu = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. + cute_tensor = from_dlpack( + torch_tensor_gpu, assumed_align=16, use_32bit_stride=True + ) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + if not is_lse: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=leading_dim, + stride_order=stride_order, + divisibility=(128 // dtype.width), + ) + + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor_gpu + + def create_cache_seqs(batch_size, seq_len, is_var_seq): + cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len + cache_seqs_gpu = cache_seqs_ref.cuda() + # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. + cache_seqs = from_dlpack( + cache_seqs_gpu, assumed_align=16, use_32bit_stride=True + ).mark_layout_dynamic() + if is_var_seq: + max_seq_len = seq_len + min_seq_len = int(seq_len * 0.8) + cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( + (batch_size,), + torch.int32, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=min_seq_len, max_val=max_seq_len + 1 + ), + ) + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack( + # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. + cache_seqs_gpu, + assumed_align=16, + use_32bit_stride=True, + ).mark_layout_dynamic() + return cache_seqs_ref, cache_seqs, cache_seqs_gpu + + def create_page_table(batch_size, seq_len, is_var_seq, use_page_table, page_size): + page_table_ref, page_table, page_table_gpu = None, None, None + if use_page_table: + max_seq_len = seq_len if not is_var_seq else torch.max(cache_seqs_ref) + page_count = ceil_div(max_seq_len, page_size) + page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) + # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. + for b in range(batch_size): + for j in range(page_count): + page_table_ref[b, j] = b + j * batch_size + page_table_gpu = page_table_ref.permute(1, 0).cuda() + # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. + page_table = from_dlpack( + page_table_gpu, assumed_align=16, use_32bit_stride=True + ).mark_layout_dynamic(leading_dim=0) + return page_table_ref, page_table, page_table_gpu + + def create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ): + block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None + # check if split_kv is valid otherwise do auto setting of split_kv + if is_var_split_kv: + block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) + for b in range(batch_size): + block_split_kvs_ref[b] = ( + BlackwellMultiHeadLatentAttentionForward.get_split_kv( + batch_size, + cache_seqs_ref[b].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + ) + split_kv = torch.max(block_split_kvs_ref).item() + block_split_kvs_gpu = block_split_kvs_ref.cuda() + # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. + block_split_kvs = from_dlpack( + block_split_kvs_gpu, assumed_align=16, use_32bit_stride=True + ).mark_layout_dynamic() + elif split_kv <= 0: + split_kv = BlackwellMultiHeadLatentAttentionForward.get_split_kv( + batch_size, + cache_seqs_ref[0].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu + + def create_workspace(num_heads, latent_dim, batch_size, split_kv, acc_dtype): + workspace_size = BlackwellMultiHeadLatentAttentionForward.get_workspace_size( + num_heads, + latent_dim, + batch_size, + split_kv, + acc_dtype, + ) + + workspace, workspace_torch = None, None + if workspace_size > 0: + workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() + # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. + workspace = from_dlpack( + workspace_torch, assumed_align=16, use_32bit_stride=True + ) + return workspace, workspace_torch + + cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( + batch_size, seq_len, is_var_seq + ) + page_table_ref, page_table, page_table_torch = create_page_table( + batch_size, seq_len, is_var_seq, use_page_table, page_size + ) + cluster_shape_mnk = (2, 1, 1) + hardware_info = utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ) + split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( + create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + ) + + q_latent_ref, q_latent, q_latent_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + ) + q_rope_ref, q_rope, q_rope_torch = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + ) + + c_latent_ref, c_latent, c_latent_torch = create_data_tensor( + batch_size, + seq_len, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + c_rope_ref, c_rope, c_rope_torch = create_data_tensor( + batch_size, + seq_len, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + o_ref, o, o_torch = create_data_tensor( + batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True + ) + lse_ref, lse, lse_torch = create_data_tensor( + batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True + ) + workspace, workspace_torch = create_workspace( + num_heads, latent_dim, batch_size, split_kv, acc_dtype + ) + + mla = BlackwellMultiHeadLatentAttentionForward( + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + max_active_clusters, + is_persistent, + is_cpasync, + use_page_table, + is_var_seq, + is_var_split_kv, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + stream = cuda.CUstream(torch_stream.cuda_stream) + + # compile mla kernel + compiled_mla = cute.compile( + mla, + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + + def torch_reference_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + cache_seqs, + softmax_scale=1.0, + output_scale=1.0, + ): + # expand and concat q_latent and q_rope to have the dimension of sequence length for q + q_ref = torch.cat([q_latent, q_rope], dim=1).permute(2, 0, 1).unsqueeze(2) + # expand and concat c_latent and c_rope to have the dimension of num_heads for k and v + if use_page_table: + page_count = page_table_ref.shape[1] + k_ref_paged = ( + torch.cat([c_latent, c_rope], dim=1) + .permute(2, 0, 1) + .reshape(batch_size * page_count, page_size, latent_dim + rope_dim) + ) + v_ref_paged = c_latent.permute(2, 0, 1).reshape( + batch_size * page_count, page_size, latent_dim + ) + + if is_var_seq: + max_seq_len = torch.max(cache_seqs_ref) + else: + max_seq_len = seq_len + + k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) + v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) + k_ref = torch.index_select( + k_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] + v_ref = torch.index_select( + v_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] + for b in range(batch_size): + k_ref[b, :, cache_seqs_ref[b] :, :] = 0 + v_ref[b, :, cache_seqs_ref[b] :, :] = 0 + else: + k_ref = torch.cat([c_latent, c_rope], dim=1).permute(2, 0, 1).unsqueeze(1) + v_ref = c_latent.permute(2, 0, 1).unsqueeze(1) + + o_ref = F.scaled_dot_product_attention( + q_ref, + k_ref, + v_ref, + attn_mask=None, + dropout_p=0.0, + scale=softmax_scale, + is_causal=False, + ) + s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) + s_ref_max = torch.max(s_ref, dim=-1, keepdim=True).values + softmax_scale_log2 = LOG2_E * softmax_scale + s_ref_sum = torch.sum( + torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True + ) + lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) + lse_ref = lse_ref.squeeze(3).squeeze(2).permute(1, 0) + o_ref = o_ref * output_scale + o_ref = o_ref.squeeze(2).permute(1, 2, 0) + + return o_ref, lse_ref + + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + torch.cuda.synchronize() + print("Verifying results...") + if in_dtype == cutlass.Float8E4M3FN: + tolerance = 0.13 + o_ref, lse_ref = torch_reference_mla( + q_latent_ref, + q_rope_ref, + c_latent_ref, + c_rope_ref, + page_table, + cache_seqs, + softmax_scale, + output_scale, + ) + + if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o, o_fp32) + o = o_fp32_torch.cpu() + ref_fp8, _ = cutlass_torch.cute_tensor_like( + torch.empty(*o_ref.permute(2, 0, 1).shape, dtype=torch.uint8).permute( + 1, 2, 0 + ), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + o_ref_gpu = o_ref.cuda() + # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. + o_ref_f32 = from_dlpack( + o_ref_gpu, use_32bit_stride=True + ).mark_layout_dynamic(leading_dim=1) + + # convert ref : f32 -> fp8 -> f32 + cute.testing.convert(o_ref_f32, ref_fp8) + cute.testing.convert(ref_fp8, o_ref_f32) + + o_ref = o_ref_gpu.cpu() + else: + o = o_torch.cpu().to(torch.float32) + lse = lse_torch.cpu() + lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) + # Assert close results + torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) + torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len, is_var_seq) + _, page_table, _ = create_page_table( + batch_size, seq_len, is_var_seq, use_page_table, page_size + ) + _split_kv, _, block_split_kvs, _ = create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + + _, q_latent, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + ) + _, q_rope, _ = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + ) + + _, c_latent, _ = create_data_tensor( + batch_size, + seq_len, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, c_rope, _ = create_data_tensor( + batch_size, + seq_len, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, o, _ = create_data_tensor( + batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True + ) + _, lse, _ = create_data_tensor( + batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True + ) + workspace, workspace_torch = create_workspace( + num_heads, latent_dim, batch_size, _split_kv, acc_dtype + ) + return testing.JitArguments( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + _split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_latent_torch.numel() * q_latent_torch.element_size() + + q_rope_torch.numel() * q_rope_torch.element_size() + + c_latent_torch.numel() * c_latent_torch.element_size() + + c_rope_torch.numel() * c_rope_torch.element_size() + + o_torch.numel() * o_torch.element_size() + + lse_torch.numel() * lse_torch.element_size() + + cache_seqs_torch.numel() * cache_seqs_torch.element_size() + ) + if use_page_table: + one_workspace_bytes += ( + page_table_torch.numel() * page_table_torch.element_size() + ) + if is_var_split_kv: + one_workspace_bytes += ( + block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() + ) + if workspace_torch is not None: + one_workspace_bytes += ( + workspace_torch.numel() * workspace_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_mla, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return avg_time_us # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: + ret = parse_comma_separated_ints(s) + if len(ret) != 2: + raise argparse.ArgumentTypeError( + "Invalid format. Expected 2 comma-separated integers." + ) + return (ret[0], ret[1]) + + parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") + + parser.add_argument( + "--in_dtype", + type=cutlass.dtype, + default=cutlass.Float8E4M3FN, + help="Input data type", + ) + + parser.add_argument( + "--out_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Output data type", + ) + + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="Accumulator data type", + ) + + parser.add_argument( + "--lse_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="LSE data type", + ) + parser.add_argument( + "--mma_qk_tiler_mn", + type=parse_mma_tiler, + default=(128, 128), + help="MMA tile shape (H, K)", + ) + parser.add_argument( + "--mma_pv_tiler_mn", + type=parse_mma_tiler, + default=(128, 256), + help="MMA tile shape (H, D)", + ) + + parser.add_argument( + "--is_persistent", + action="store_true", + help="Is persistent", + ) + + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size", + ) + + parser.add_argument( + "--seq_len", + type=int, + default=128, + help="Sequence length of K/V", + ) + + parser.add_argument( + "--num_heads", + type=int, + default=128, + help="Number of heads of Q", + ) + + parser.add_argument( + "--latent_dim", + type=int, + default=512, + help="Latent dimension of Q/C", + ) + + parser.add_argument( + "--rope_dim", + type=int, + default=64, + help="Rope dimension of Q/C", + ) + + parser.add_argument( + "--is_cpasync", + action="store_true", + help="Use cpasync for load or not", + ) + + parser.add_argument( + "--is_var_seq", + action="store_true", + help="Use variable length of sequence length or not", + ) + + parser.add_argument( + "--is_var_split_kv", + action="store_true", + help="Use variable length of split kv or not", + ) + + parser.add_argument( + "--use_page_table", + action="store_true", + help="Use page table or not, must be True when is_cpasync is True", + ) + + parser.add_argument( + "--page_size", + type=int, + default=128, + help="Page size of page table", + ) + + parser.add_argument( + "--split_kv", + type=int, + default=-1, + help="Split KV setting", + ) + + parser.add_argument( + "--softmax_scale", + type=float, + default=1.0, + help="Scaling factor to scale softmax", + ) + + parser.add_argument( + "--output_scale", + type=float, + default=1.0, + help="Scaling factor to scale output", + ) + + parser.add_argument( + "--tolerance", type=float, default=1e-02, help="Tolerance for validation" + ) + + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of iterations for warmup", + ) + + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations after warmup", + ) + + parser.add_argument( + "--skip_ref_check", + action="store_true", + help="Skip reference check", + ) + + parser.add_argument( + "--use_cold_l2", + action="store_true", + help="Use cold L2 cache", + ) + + args = parser.parse_args() + + run( + args.batch_size, + args.seq_len, + args.num_heads, + args.latent_dim, + args.rope_dim, + args.in_dtype, + args.out_dtype, + args.acc_dtype, + args.lse_dtype, + args.mma_qk_tiler_mn, + args.mma_pv_tiler_mn, + args.split_kv, + args.is_persistent, + args.is_cpasync, + args.is_var_seq, + args.is_var_split_kv, + args.use_page_table, + args.page_size, + args.softmax_scale, + args.output_scale, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py b/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py new file mode 100644 index 00000000..44948c26 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py @@ -0,0 +1,381 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +from cutlass.cute.runtime import from_dlpack + + +def supports_pdl(): + return torch.cuda.get_device_capability()[0] >= 9 + + +""" +This example demonstrates the use of Programmatic Dependent Launch (PDL) using +CuTe DSL. + +PDL is a mechanism which allows for overlapping execution of back-to-back kernels +within the same stream. +For example, consider the following two elementwise add operations, where the second +operation's first operand is the result of the first operation. While performing +``w = u + v`` we will load u and v, add them, and then store the result. Once we +have finished loading data, we are no longer utilizing the read bandwidth. +To effectively utilize the read bandwidth, we can start loading ``x`` +immediately upon finishing reading. This is what PDL enables us to do. + +.. code-block:: bash + +w = u + v +y = w + x + +To enable PDL, we need to do two things: + +1. Insert the ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait`` instructions in the kernel. +2. Set the PDL launch attribute when launching the kernel. + +The ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait`` +instructions enable fine-grained control over kernel execution in PDL. +Once all thread blocks execute the ``griddepcontrol.launch_dependents`` +instruction, the dependent kernels can opportunistically be early-launched. +``griddepcontrol.wait`` functions as a synchronization barrier - any warp +executing this instruction will block until the previous kernel finishes +execution. This allows precise control over data dependencies between kernels. + +The following diagram shows the overlapping execution of two dependent kernels. +We call the instructions before ``griddepcontrol.wait`` as prologue (``P0``), +which may include barrier initialization and loading of independent data, etc. +We call the instructions after ``griddepcontrol.launch_dependents`` as epilogue +(``P2``), which may include math operations, data stores, etc. PDL enables +these prologue and epilogue phases to execute concurrently across dependent +kernels, improving GPU resource utilization. This is particularly beneficial +when prologue and epilogue are bound by different resources (e.g., memory +bandwidth vs compute throughput). + + # P0: Prologue, P1: Main compute, P2: Epilogue + + P0 P1 P2 + K1: |=====|+++++|-----| + + <-----> K2 can start early + (K1's P2 overlaps with K2's P0) + + P0 P1 P2 + K2: |=====| |+++++|-----| + ^ + | + wait for K1 to complete +Time ------------------------------------------------------> + +We could run this example with and without PDL: + +.. code-block:: bash + + python examples/blackwell/programmatic_dependent_launch.py --benchmark + python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl + +From the benchmark results, you can see some speedups for the PDL version in most cases, benefiting from +the overlapping execution of consecutive kernels. Moreover, you can use nsys to observe the overlapping execution. + +.. code-block:: bash + + nsys profile python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl + +Note, PDL feature is supported on Hopper and later GPUs. + +See [the programming guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization) +and the [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol) +for more details. +""" + + +@cute.kernel +def elementwise_add_kernel( + gA: cute.Tensor, + gB: cute.Tensor, + gC: cute.Tensor, + cC: cute.Tensor, # coordinate tensor + shape: cute.Shape, + thr_layout: cute.Layout, + val_layout: cute.Layout, + use_pdl: cutlass.Constexpr = True, + is_first_kernel: cutlass.Constexpr = True, +): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + blk_coord = ((None, None), bidx) + blkA = gA[blk_coord] # (TileM,TileN) + blkB = gB[blk_coord] # (TileM,TileN) + blkC = gC[blk_coord] # (TileM,TileN) + blkCrd = cC[blk_coord] # (TileM, TileN) + + copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type) + copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type) + + tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + thr_copy_C = tiled_copy_C.get_slice(tidx) + + thrA = thr_copy_A.partition_S(blkA) + thrB = thr_copy_B.partition_S(blkB) + thrC = thr_copy_C.partition_S(blkC) + + frgA = cute.make_fragment_like(thrA) + frgB = cute.make_fragment_like(thrB) + frgC = cute.make_fragment_like(thrC) + + thrCrd = thr_copy_C.partition_S(blkCrd) + frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean) + + for i in range(cute.size(frgPred)): + val = cute.elem_less(thrCrd[i], shape) + frgPred[i] = val + + # Note: when not using cuda-graph, the kernel execution may be blocked by the host overhead. + # In this case we won't see overlapping even when pdl is enabled. + # In this example, we add a loop (10 times) for all the copy and compute operations in the following code + # to make kernel running longer and make pdl benefits observable for both cuda-graph enabled and disabled cases. + if not use_pdl: + for _ in range(10): + cute.copy(copy_atom_load, thrA, frgA, pred=frgPred) + cute.copy(copy_atom_load, thrB, frgB, pred=frgPred) + else: + if is_first_kernel: + for _ in range(10): + cute.copy(copy_atom_load, thrA, frgA, pred=frgPred) + cute.copy(copy_atom_load, thrB, frgB, pred=frgPred) + # Here we add the launch dependents instruction for the first kernel as a hint to the runtime to early-launch + # the next kernel. If the next kernel becomes concurrent, we will have overlap where the second kernel + # can start reading x to ensure an E2E speedup. Note the placement of launch dependents has no implication + # on correctness, only performance. + cute.arch.griddepcontrol_launch_dependents() + else: + # In this example, the second kernel's second operand ``gB`` has no dependencies, its loading can overlap + # with the computation of ``gC`` from the first kernel. + for _ in range(10): + cute.copy(copy_atom_load, thrB, frgB, pred=frgPred) + + # For the second kernel, its first operand ``gA`` is dependent on the previous kernel, we must call + # griddepcontrol.wait to assure correctness. This instruction will block until the prior kernels finishes + # and its memory operations are visible. Since gA is written by the prior kernel, this will block until gA + # is visible to our kernel. Without it, we would have undefined behavior due to a race condition. + cute.arch.griddepcontrol_wait() + + for _ in range(10): + cute.copy(copy_atom_load, thrA, frgA, pred=frgPred) + + for _ in range(10): + result = frgA.load() + frgB.load() + frgC.store(result) + cute.copy(copy_atom_store, frgC, thrC, pred=frgPred) + + +@cute.jit +def elementwise_add( + mA, + mB, + mC, + stream: cuda.CUstream, + use_pdl: cutlass.Constexpr = True, + is_first_kernel: cutlass.Constexpr = True, +): + dtype = mA.element_type + # copy_bits for a thread is 128 bits, and we use 128 // dtype.width to get the vector size + vector_size = 128 // dtype.width + + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0)) + tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + + idC = cute.make_identity_tensor(mC.shape) + cC = cute.zipped_divide(idC, tiler=tiler_mn) + + elementwise_add_kernel( + gA, gB, gC, cC, mC.shape, thr_layout, val_layout, use_pdl, is_first_kernel + ).launch( + grid=[cute.size(gC, mode=[1]), 1, 1], + block=[cute.size(tv_layout, mode=[0]), 1, 1], + # set cluster to enable cuLaunchKernelEx API for additional launch attributes setting + cluster=(1, 1, 1), + stream=stream, + # Currently, pdl launch attribute is set in compile phase, + # so we need to recompile the function if we change the value of use_pdl for multiple runs. + use_pdl=use_pdl, + ) + + +def run_pdl_example( + M, + N, + skip_ref_check=False, + benchmark=True, + warmup_iterations=5, + iterations=10, + use_pdl=True, +): + if not torch.cuda.is_available(): + raise RuntimeError("Blackwell/Hopper GPU is required to run this example!") + + print("\nRunning Elementwise Add test with:") + print(f"Tensor dimensions: [{M}, {N}]") + print(f"Use PDL: {use_pdl}") + + u = torch.randn(M, N, dtype=torch.float32, device="cuda") + v = torch.randn(M, N, dtype=torch.float32, device="cuda") + w = torch.randn(M, N, dtype=torch.float32, device="cuda") + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + y = torch.empty(M, N, dtype=torch.float32, device="cuda") + + u_tensor = from_dlpack(u).mark_layout_dynamic() + v_tensor = from_dlpack(v).mark_layout_dynamic() + w_tensor = from_dlpack(w).mark_layout_dynamic() + x_tensor = from_dlpack(x).mark_layout_dynamic() + y_tensor = from_dlpack(y).mark_layout_dynamic() + + stream = torch.cuda.Stream() + current_stream = cuda.CUstream(stream.cuda_stream) + # Since use_pdl and is_first_kernel are cutlass.Constexpr, we need to compile for + # the first and second kernel separately. + compiled_func_first_kernel = cute.compile( + elementwise_add, + u_tensor, + v_tensor, + w_tensor, + current_stream, + use_pdl, + is_first_kernel=True, + ) + compiled_func_second_kernel = cute.compile( + elementwise_add, + w_tensor, + x_tensor, + y_tensor, + current_stream, + use_pdl, + is_first_kernel=False, + ) + + # launch and run the two consecutive kernels in a same stream. + # Here, we simply use default stream. + def run_func(current_stream, u_tensor, v_tensor, w_tensor, x_tensor, y_tensor): + # Run first operation: w_tensor = u_tensor + v_tensor + compiled_func_first_kernel( + u_tensor, + v_tensor, + w_tensor, + current_stream, + ) + # Run second operation: y_tensor = w_tensor + x_tensor + # its first operand ``w_tensor`` is the result of the first operation, + # they use the same memory space. + compiled_func_second_kernel( + w_tensor, + x_tensor, + y_tensor, + current_stream, + ) + + if not skip_ref_check: + run_func(current_stream, u_tensor, v_tensor, w_tensor, x_tensor, y_tensor) + print("Verifying results...") + torch.testing.assert_close(u.cpu() + v.cpu() + x.cpu(), y.cpu()) + print("Results verified successfully!") + + if not benchmark: + return + + def generate_kernel_arguments(): + u = torch.randn(M, N, dtype=torch.float32, device="cuda") + v = torch.randn(M, N, dtype=torch.float32, device="cuda") + w = torch.randn(M, N, dtype=torch.float32, device="cuda") + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + y = torch.empty(M, N, dtype=torch.float32, device="cuda") + + u_tensor = from_dlpack(u).mark_layout_dynamic() + v_tensor = from_dlpack(v).mark_layout_dynamic() + w_tensor = from_dlpack(w).mark_layout_dynamic() + x_tensor = from_dlpack(x).mark_layout_dynamic() + y_tensor = from_dlpack(y).mark_layout_dynamic() + return testing.JitArguments( + current_stream, u_tensor, v_tensor, w_tensor, x_tensor, y_tensor + ) + + avg_time_us = testing.benchmark( + run_func, + workspace_generator=generate_kernel_arguments, + workspace_count=10, + warmup_iterations=warmup_iterations, + iterations=iterations, + stream=current_stream, + ) + print(f"Execution time: {avg_time_us:.4f} us") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="example of Programmatic Dependent Launch (PDL) using CuTe DSL" + ) + parser.add_argument("--M", default=512, type=int) + parser.add_argument("--N", default=512, type=int) + parser.add_argument("--warmup_iterations", default=3, type=int) + parser.add_argument("--iterations", default=10, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument("--benchmark", action="store_true") + parser.add_argument("--use_pdl", action="store_true") + + args = parser.parse_args() + if supports_pdl(): + run_pdl_example( + args.M, + args.N, + skip_ref_check=args.skip_ref_check, + benchmark=args.benchmark, + warmup_iterations=args.warmup_iterations, + iterations=args.iterations, + use_pdl=args.use_pdl, + ) + print("\nPASS") + else: + print( + "PDL is not supported on this device, it requires Hopper or newer generations" + ) diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/README.md b/examples/python/CuTeDSL/blackwell/tutorial_gemm/README.md new file mode 100644 index 00000000..fc9b2f26 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/README.md @@ -0,0 +1,25 @@ +# CUTLASS Tutorial Examples for Blackwell GEMM + +This folder contains tutorial examples demonstrating how to write performant GEMM (General Matrix Multiplication) kernels using Tensor Cores on NVIDIA Blackwell GPUs. + +## Overview + +The examples showcase different scenarios and optimization techniques for implementing GEMM operations: + +- Basic FP16 GEMM implementation +- Software Pipeline optimizations +- Tensor Core utilization +- Thread/warp/block level parallelism + +## Examples + +### tutorial_fp16_gemm_0.py + +A basic example showing: +- FP16 GEMM implementation using Tensor Cores +- TMA (Tensor Memory Access) for efficient data loading +- SMEM (Shared Memory) layouts and access patterns +- Usage of ``cutlass.range(..., prefetch_stages=...)`` to replace boilerplate code for multi-stage software pipeline + +With some minor optimization tricks +- Tiling Epilogue to avoid bursty write out and reduce register pressure diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py new file mode 100644 index 00000000..6e720235 --- /dev/null +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py @@ -0,0 +1,444 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import argparse +import torch +from typing import Tuple + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.torch as cutlass_torch +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +""" +The first tutorial GEMM demonstrating a simple kernel implementation in CuTeDSL + +This dense GEMM kernel is implemented in just over 200 lines of code. +With large tile sizes, it can achieve very high performance on 8k×8k×8k problem sizes. +It can serve as a starting point to help users quickly experiment +with optimizations for challenges that may arise with other problem sizes. + +To run this example: +.. code-block:: bash + + python examples/blackwell/tutorial_fp16_gemm_0.py \ + --mnk 8192,8192,8192 \ + --tolerance 1e-01 + +Constraints for this example: +* The problem size of m and n must be divisible by the tile size m & n (128, 256) +""" + +io_dtype = cutlass.Float16 +acc_dtype = cutlass.Float32 +mma_inst_shape_mnk = (128, 256, 16) +mma_tiler_mnk = (128, 256, 64) +threads_per_cta = 128 + +# Pipeline stage configuration +ab_stages = 4 +acc_stage = 1 + + +@cute.struct +class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + + +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, + a_smem_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, +): + # Current thread/warp/block coordinates + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + bidx, bidy, _ = cute.arch.block_idx() + mma_coord_mnk = (bidx, bidy, None) + + # + # 1. Prepare args + # + + # Allocate SMEM + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sA = smem.allocate_tensor( + element_type=io_dtype, + layout=a_smem_layout.outer, + byte_alignment=128, + swizzle=a_smem_layout.inner, + ) + sB = smem.allocate_tensor( + element_type=io_dtype, + layout=b_smem_layout.outer, + byte_alignment=128, + swizzle=b_smem_layout.inner, + ) + + # Allocate all TMEM columns + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=threads_per_cta, + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + ) + num_tmem_cols = 512 + tmem.allocate(num_tmem_cols) + + # Prefetch tma descriptor + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + + # Pipeline configuration + num_tma_copy_bytes = cute.size_in_bytes( + io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]) + ) + cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2])) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + num_stages=ab_stages, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=num_tma_copy_bytes, + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, threads_per_cta + ), + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + ).make_participants() + + # Partition tensors for MMA and make fragments + # (bM, bK, RestK) + gA = cute.local_tile(mA_mkl, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1)) + # (bN, bK, RestK) + gB = cute.local_tile(mB_nkl, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1)) + # (bM, bN) + gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None)) + thr_mma = tiled_mma.get_slice(0) + # (MMA, MMA_M, MMA_K) + tCgA = thr_mma.partition_A(gA) + # (MMA, MMA_N, MMA_K) + tCgB = thr_mma.partition_B(gB) + # (MMA, MMA_M, MMA_N) + tCgC = thr_mma.partition_C(gC) + # (MMA, MMA_M, MMA_K) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc = tiled_mma.make_fragment_C(acc_shape) + # Partition tensors for TMA; This requires the tensors partitioned for MMA + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + tBsB, tBgB = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # CTA-wide sync before retrieving the pointer to the start of the allocated TMEM + # Only warp 0 does the allocation so we need to sync before retrieving the TMEM start address + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(acc_dtype) + # Swap the pointer in tCtAcc + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout) + + subtile_cnt = 4 + # (EpiTile) + epi_tiler = ( + (cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt), + ) + # (EpiTile, NumTiles) + tCtAcc_epi = cute.zipped_divide(tCtAcc, epi_tiler) + # (EpiTile, NumTiles) + gC_epi = cute.zipped_divide(tCgC, epi_tiler) + + # Every thread loads 32x128 bits + tmem_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x64), + cutlass.Float32, + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0]) + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + + # (TmemCpy,NumTmemCpy,NumTiles) + tDtC = tmem_thr_copy.partition_S(tCtAcc_epi) + # (TmemCpy,NumTmemCpy,NumTiles) + tDgC = tmem_thr_copy.partition_D(gC_epi) + + # (TmemCpy,NumTmemCpy) + tCrAcc = cute.make_rmem_tensor(tDgC[None, None, 0].shape, acc_dtype) + # (TmemCpy,NumTmemCpy) + tCrC = cute.make_rmem_tensor(tDgC[None, None, 0].shape, io_dtype) + + # + # 2. Main loop + # + num_k_tiles = cute.size(gA, mode=[2]) + if warp_idx == 0: + # Wait for a empty accumulator buffer + acc_empty = acc_producer.acquire_and_advance() + for k_tile_idx in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2): + # Issue TMA loads + ab_empty = ab_producer.acquire_and_advance() + cute.copy( + tma_atom_a, + tAgA[(None, ab_empty.count)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_empty.count)], + tBsB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + + # Execute one K-block worth of MMA instructions + ab_full = ab_consumer.wait_and_advance() + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, ab_full.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Signal that the A/B buffers have been consumed and are ready for the next load + ab_full.release() + + # Signal that the accumulator is fully computed + acc_empty.commit() + + # + # 3. Epilogue + # + + # Release TMEM allocation lock + tmem.relinquish_alloc_permit() + + # Wait for the accumulator buffer to be full + acc_full = acc_consumer.wait_and_advance() + + # TMEM -> RMEM -> GEMM + # Sub-tiling for better instruction-level parallelism + for i in cutlass.range(cute.size(tDtC, mode=[2])): + cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc) + tCrC.store(tCrAcc.load().to(io_dtype)) + cute.autovec_copy(tCrC, tDgC[None, None, i]) + acc_full.release() + + # Deallocate TMEM + pipeline.sync(barrier_id=1) + tmem.free(tmem_ptr) + + +@cute.jit +def host_function( + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, +): + # Construct tiled MMA + op = tcgen05.MmaF16BF16Op( + io_dtype, + acc_dtype, + mma_inst_shape_mnk, + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + ) + tiled_mma = cute.make_tiled_mma(op) + + # Construct SMEM layouts for A and B + a_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a.element_type, + ab_stages, + ) + b_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b.element_type, + ab_stages, + ) + a_smem_layout_one_stage = cute.select(a_smem_layout, mode=[0, 1, 2]) + b_smem_layout_one_stage = cute.select(b_smem_layout, mode=[0, 1, 2]) + + # Construct TMA load atoms + op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A( + op, + a, + a_smem_layout_one_stage, + mma_tiler_mnk, + tiled_mma, + ) + b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B( + op, + b, + b_smem_layout_one_stage, + mma_tiler_mnk, + tiled_mma, + ) + + # Pretty prints kernel attributes useful for debugging + # print(f"a = {cute.pretty_str(a)}") + # print(f"b = {cute.pretty_str(b)}") + # print(f"c = {cute.pretty_str(c)}") + # print(f"tiled_mma = {cute.pretty_str(tiled_mma)}") + # print(f"a_tma_atom = {cute.pretty_str(a_tma_atom)}") + # print(f"b_tma_atom = {cute.pretty_str(b_tma_atom)}") + # print(f"a_tma_tensor = {cute.pretty_str(a_tma_tensor)}") + # print(f"b_tma_tensor = {cute.pretty_str(b_tma_tensor)}") + + # Launch the kernel + grid_shape = cute.ceil_div((*c.layout.shape, 1), mma_tiler_mnk[:2]) + kernel( + tiled_mma, + a_tma_atom, + a_tma_tensor, + b_tma_atom, + b_tma_tensor, + c, + a_smem_layout, + b_smem_layout, + ).launch( + grid=grid_shape, + block=(threads_per_cta, 1, 1), + ) + + +def run_dense_gemm( + mnk: Tuple[int, int, int], + tolerance: float, +): + print("===================================================================") + print("Running Blackwell fp16 GEMM example 0 with:") + print(f" mnk: {mnk}") + print(f" tolerance: {tolerance}") + print("===================================================================") + print() + + m, n, k = mnk + torch.manual_seed(1111) + + # Make K-major tensors (torch tensors are row-major) + def make_tensors(mn, k, dtype): + shape = (mn, k) + return ( + torch.empty(*shape, dtype=torch.int32) + .random_(-2, 2) + .to(dtype=dtype, device="cuda") + ) + + a = make_tensors(m, k, cutlass_torch.dtype(io_dtype)) + b = make_tensors(n, k, cutlass_torch.dtype(io_dtype)) + c = make_tensors(m, n, cutlass_torch.dtype(io_dtype)) + a_tensor = ( + from_dlpack(a, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=k) + ) + b_tensor = ( + from_dlpack(b, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=k) + ) + c_tensor = ( + from_dlpack(c, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=n) + ) + + # Entry point to the host JIT function + host_function( + a_tensor, + b_tensor, + c_tensor, + no_cache=True, + ) + + # Compute reference result and verify + ref = (torch.einsum("mk,nk->mn", a.to(torch.float32), b.to(torch.float32))).cpu() + + torch.testing.assert_close( + c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05 + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str): + try: + return [int(x.strip()) for x in s.split(",")] + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + if not torch.cuda.is_available(): + raise RuntimeError("A GPU is required to run this example") + + parser = argparse.ArgumentParser(description="Blackwell fp16 GEMM example 0") + parser.add_argument( + "--mnk", + type=parse_comma_separated_ints, + default=[8192, 8192, 8192], + help="MNK dimensions (comma-separated)", + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + args = parser.parse_args() + if len(args.mnk) != 3: + parser.error("--mnk must contain exactly 3 values") + if args.mnk[0] % mma_tiler_mnk[0] != 0 or args.mnk[1] % mma_tiler_mnk[1] != 0: + parser.error("m n must be divisible by mma_tiler_mn") + + run_dense_gemm( + args.mnk, + args.tolerance, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py b/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py new file mode 100644 index 00000000..2ffe80fd --- /dev/null +++ b/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py @@ -0,0 +1,1326 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Tuple, Type + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils.hopper_helpers as sm90_utils + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell Geforce architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes non-Tensor Core MMA for matrix multiply-accumulate (MMA) operations + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using non-Tensor Core MMA instruction. +3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations. + +Non-Tensor Core MMA instructions operate as follows: +- Read matrix A from registers +- Read matrix B from registers +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/blackwell_geforce/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \ + --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The tile shape is 128x256x64 and the cluster shape is (1,1). +The input, mma accumulator and output data type are set as fp16, fp32 +and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell_geforce/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \ + --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +Constraints: +* Supported input data types: fp16, bf16 +* For fp16 types, A and B must have the same data type +* Only fp32 accumulation is supported in this example +* CTA tile shape M must be 64/128 +* CTA tile shape N must be 64/128/256 +* CTA tile shape K must be 64 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 4 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8, 16 for Float16, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +# ///////////////////////////////////////////////////////////////////////////// +# Helpers to parse args +# ///////////////////////////////////////////////////////////////////////////// +def parse_comma_separated_ints(s: str): + try: + return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell Geforce." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(4096, 4096, 4096, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--tile_shape_mnk", + type=parse_comma_separated_ints, + choices=[ + (64, 64, 64), + (64, 128, 64), + (128, 64, 64), + (128, 128, 64), + (128, 256, 64), + (128, 128, 128), + ], + default=(64, 64, 64), + help="CTA tile shape (comma-separated)", + ) + parser.add_argument( + "--a_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--b_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", + action="store_true", + default=False, + help="Skip reference checking", + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + return args + + +# ///////////////////////////////////////////////////////////////////////////// +# Host setup and device kernel launch +# ///////////////////////////////////////////////////////////////////////////// + + +class Sm120GemmKernel: + def __init__( + self, + acc_dtype, + tile_shape_mnk, + ): + self.acc_dtype = acc_dtype + self.cluster_shape_mnk = (1, 1, 1) + self.tile_shape_mnk = tuple(tile_shape_mnk) + self.tiled_mma = None + self.num_mcast_ctas_a = None + self.num_mcast_ctas_b = None + self.is_a_mcast = False + self.is_b_mcast = False + + self.occupancy = 1 + # TODO: remove this hard code for user input ? + self.atom_layout = (2, 2, 1) + self.num_mma_warps = ( + self.atom_layout[0] * self.atom_layout[1] * self.atom_layout[2] + ) + self.num_threads_per_warp = 32 + self.threads_per_cta = ( + self.num_mma_warps + 1 # 1 warp for DMA + ) * self.num_threads_per_warp + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_120") + + self.ab_stage = None + self.epi_stage = None + + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + + self.shared_storage = None + self.buffer_align_bytes = 1024 + + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=self.num_mma_warps * self.num_threads_per_warp, + ) + self.load_register_requirement = 40 + self.mma_register_requirement = 232 + + def _setup_attributes(self): + # TODO: remove this hard code for user input ? + self.mma_inst_mnk = (16, 8, 16) + op = cute.nvgpu.warp.MmaF16BF16Op( + self.a_dtype, + self.acc_dtype, + self.mma_inst_mnk, + ) + tC = cute.make_layout(self.atom_layout) + permutation_mnk = ( + self.atom_layout[0] * self.mma_inst_mnk[0], + # TODO: to leverage ldmatrix.x4, when self.atom_layout[1] is 1, mma tile is ((8x16)x2) + self.atom_layout[1] * self.mma_inst_mnk[1] * 2, + self.atom_layout[2] * self.mma_inst_mnk[2], + ) + self.tiled_mma = cute.make_tiled_mma( + op, + tC, + permutation_mnk=permutation_mnk, + ) + + self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk) + + self.num_mcast_ctas_a = self.cluster_shape_mnk[1] + self.num_mcast_ctas_b = self.cluster_shape_mnk[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + self.epi_tile = sm90_utils.compute_tile_shape_or_override( + self.tile_shape_mnk, self.c_dtype, is_cooperative=False + ) + + # Compute stage before compute smem layout + self.ab_stage, self.epi_stage = self._compute_stages( + self.tile_shape_mnk, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.smem_capacity, + self.occupancy, + ) + + import sys + + if self.ab_stage == 0: + print("ab_stage == 0, no enough shared memory. This case will be skipped.") + sys.exit(0) + + ( + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ) = self._make_smem_layouts( + self.tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.a_layout, + self.b_dtype, + self.b_layout, + self.ab_stage, + self.c_dtype, + self.c_layout, + self.epi_stage, + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + """ + + # setup static attributes before smem/grid/tma computation + self.a_dtype = a.element_type + self.b_dtype = b.element_type + self.c_dtype = c.element_type + + self.a_layout = utils.LayoutEnum.from_tensor(a) + self.b_layout = utils.LayoutEnum.from_tensor(b) + self.c_layout = utils.LayoutEnum.from_tensor(c) + + if cutlass.const_expr( + self.a_dtype.width == 16 and self.a_dtype != self.b_dtype + ): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): + raise TypeError("a_dtype should be float16 or float8") + if cutlass.const_expr(self.b_dtype.width != 16 and self.b_dtype.width != 8): + raise TypeError("b_dtype should be float16 or float8") + + self._setup_attributes() + + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + a, + self.a_smem_layout_staged, + (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), + 1, + ) + + tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( + b, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + 1, + ) + + tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors( + c, + self.epi_smem_layout_staged, + self.epi_tile, + ) + + tile_sched_params, grid = self._compute_grid( + c, + self.tile_shape_mnk, + max_active_clusters, + ) + + @cute.struct + class SharedStorage: + mainloop_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, cute.cosize(self.epi_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.tiled_mma, + self.cta_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=[1, 1, 1], + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tiled_mma: cute.TiledMma, + cta_layout_mnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: utils.PersistentTileSchedulerParams, + ): + """ + GPU device kernel performing the batched GEMM computation. + + :param tma_atom_a: TMA copy atom for A tensor + :type tma_atom_a: cute.CopyAtom + :param mA_mkl: Input tensor A + :type mA_mkl: cute.Tensor + :param tma_atom_b: TMA copy atom for B tensor + :type tma_atom_b: cute.CopyAtom + :param mB_nkl: Input tensor B + :type mB_nkl: cute.Tensor + :param tma_atom_c: TMA copy atom for C tensor + :type tma_atom_c: cute.CopyAtom + :param mC_mnl: Output tensor C + :type mC_mnl: cute.Tensor + :param tiled_mma: Tiled MMA object + :type tiled_mma: cute.TiledMma + :param cta_layout_mnk: CTA layout + :type cta_layout_mnk: cute.Layout + :param a_smem_layout_staged: Shared memory layout for A + :type a_smem_layout_staged: cute.ComposedLayout + :param b_smem_layout_staged: Shared memory layout for B + :type b_smem_layout_staged: cute.ComposedLayout + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + """ + + # /////////////////////////////////////////////////////////////////////////////// + # Get cta/warp/thread idx + # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + # bidx, bidy, bidz = cute.arch.block_idx() + # bdimx, bdimy, bdimz = cute.arch.grid_dim() + + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + # /////////////////////////////////////////////////////////////////////////////// + # Get mcast mask + # /////////////////////////////////////////////////////////////////////////////// + a_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=0 + ) + + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + tma_copy_bytes = cute.size_in_bytes( + self.a_dtype, a_smem_layout + ) + cute.size_in_bytes(self.b_dtype, b_smem_layout) + + # ///////////////////////////////////////////////////////////////////////////// + # Alloc and init AB full/empty + ACC full mbar (pipeline) + # ///////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # mbar arrays + mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() + + # Threads/warps participating in this pipeline + mainloop_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread + ) + # Each warp will constribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + consumer_arrive_cnt = mcast_size * self.num_mma_warps + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + + cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape)) + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + barrier_storage=mainloop_pipeline_array_ptr, + cta_layout_vmnk=cta_layout_vmnk, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_arrive_relaxed() + + # /////////////////////////////////////////////////////////////////////////////// + # Generate smem tensor A/B + # /////////////////////////////////////////////////////////////////////////////// + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + sC = storage.sC.get_tensor( + epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + # /////////////////////////////////////////////////////////////////////////////// + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, + cute.slice_(self.tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, + cute.slice_(self.tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, + cute.slice_(self.tile_shape_mnk, (None, None, 0)), + (None, None, None), + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition global tensor for TiledMMA_A/B/C + # ////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition shared tensor for TMA load A/B + # ////////////////////////////////////////////////////////////////////////////// + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + a_cta_crd, + a_cta_layout, + cute.group_modes(sA, 0, 2), + cute.group_modes(gA_mkl, 0, 2), + ) + + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + tBsB, tBgB = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, + b_cta_crd, + b_cta_layout, + cute.group_modes(sB, 0, 2), + cute.group_modes(gB_nkl, 0, 2), + ) + + # Make frangments + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + + tCgC = thr_mma.partition_C(gC_mnl) + acc_shape = tCgC.shape[:3] + accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype) + + # cluster wait for barrier init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_wait() + else: + pipeline.sync(barrier_id=1) + + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # Create the tile scheduler + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Create the pipeline states for producer and consumer + mainloop_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + mainloop_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + + # MMA warp group + if warp_idx < self.num_mma_warps: + cute.arch.warpgroup_reg_alloc(self.mma_register_requirement) + + num_k_blocks = cute.size(tCrA, mode=[2]) + + # /////////////////////////////////////////////////////////////////////////////// + # Copy Atom A/B retiling for TMA load A/B + # /////////////////////////////////////////////////////////////////////////////// + atom_copy_ldmatrix_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(self.a_layout.is_m_major_a(), 4), + self.a_dtype, + ) + atom_copy_ldmatrix_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(self.b_layout.is_n_major_b(), 4), + self.b_dtype, + ) + smem_tiled_copy_A = cute.make_tiled_copy_A(atom_copy_ldmatrix_A, tiled_mma) + + smem_tiled_copy_B = cute.make_tiled_copy_B(atom_copy_ldmatrix_B, tiled_mma) + + thr_copy_ldmatrix_A = smem_tiled_copy_A.get_slice(tidx) + thr_copy_ldmatrix_B = smem_tiled_copy_B.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + gC_mnl_slice = gC_mnl[(None, None, *tile_coord_mnl)] + # Clear the accumulator + accumulators.fill(0.0) + + # ///////////////////////////////////////////////////////////////////////////// + # Pipelined MAINLOOP + # ///////////////////////////////////////////////////////////////////////////// + + mainloop_consumer_state.reset_count() + + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_state + ) + + # Wait for TMA copies to complete + mainloop_pipeline.consumer_wait( + mainloop_consumer_state, peek_ab_full_status + ) + # tCsA_p: (MMA, (4, MMA_M / 4), MMA_K), tCsA_p: (MMA, (4, MMA_N / 4), MMA_K) + tCsA_p = tCsA_copy_view[None, None, None, mainloop_consumer_state.index] + tCsB_p = tCsB_copy_view[None, None, None, mainloop_consumer_state.index] + cute.copy( + smem_tiled_copy_A, + tCsA_p[None, None, 0], + tCrA_copy_view[None, None, 0], + ) + cute.copy( + smem_tiled_copy_B, + tCsB_p[None, None, 0], + tCrB_copy_view[None, None, 0], + ) + + for k_tile in range(0, k_tile_cnt - 1, 1, unroll=1): + # unroll the loop + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_next = ( + 0 if k_block_idx + 1 == num_k_blocks else k_block_idx + 1 + ) + + if k_block_idx == num_k_blocks - 1: + mainloop_pipeline.consumer_release(mainloop_consumer_state) + mainloop_consumer_state.advance() + + peek_ab_full_status = cutlass.Boolean(1) + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_state + ) + + # tCsA_p: (MMA, (4, MMA_M / 4), MMA_K), tCsA_p: (MMA, (4, MMA_N / 4), MMA_K) + tCsA_p = tCsA_copy_view[ + None, None, None, mainloop_consumer_state.index + ] + tCsB_p = tCsB_copy_view[ + None, None, None, mainloop_consumer_state.index + ] + mainloop_pipeline.consumer_wait( + mainloop_consumer_state, peek_ab_full_status + ) + + # Copy data from smem to tCrA/tCrB for the next k_block + cute.copy( + smem_tiled_copy_A, + tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next], + ) + cute.copy( + smem_tiled_copy_B, + tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next], + ) + # Gemm of the current k_block + cute.gemm( + tiled_mma, + accumulators, + tCrA[None, None, k_block_idx], + tCrB[None, None, k_block_idx], + accumulators, + ) + # end of for loop + # Hoist out last k_tile + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_next = ( + 0 if k_block_idx + 1 == num_k_blocks else k_block_idx + 1 + ) + + if k_block_idx == num_k_blocks - 1: + mainloop_pipeline.consumer_release(mainloop_consumer_state) + mainloop_consumer_state.advance() + + if k_block_next > 0: + cute.copy( + smem_tiled_copy_A, + tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next], + ) + cute.copy( + smem_tiled_copy_B, + tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next], + ) + # Gemm of the current k_block + cute.gemm( + tiled_mma, + accumulators, + tCrA[None, None, k_block_idx], + tCrB[None, None, k_block_idx], + accumulators, + ) + + # ///////////////////////////////////////////////////////////////////////////// + # EPILOG + # ///////////////////////////////////////////////////////////////////////////// + + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.c_layout, + elem_ty_d=self.c_dtype, + elem_ty_acc=self.acc_dtype, + ) + + copy_atom_C = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + self.c_layout.is_m_major_c(), + 4, + ), + self.c_dtype, + ) + + tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + + tiled_copy_r2s = cute.make_tiled_copy_S( + copy_atom_r2s, + tiled_copy_C_Atom, + ) + + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + # (R2S, R2S_M, R2S_N, PIPE_D) + tRS_sD = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rAcc = tiled_copy_r2s.retile(accumulators) + + # Allocate D registers. + rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + tRS_rD = cute.make_rmem_tensor(tRS_rD_layout.shape, self.acc_dtype) + size_tRS_rD = cute.size(tRS_rD) + + sepi_for_tma_partition = cute.group_modes(sC, 0, 2) + tcgc_for_tma_partition = cute.zipped_divide(gC_mnl_slice, self.epi_tile) + + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sepi_for_tma_partition, + tcgc_for_tma_partition, + ) + + epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1]) + epi_tile_shape = tcgc_for_tma_partition.shape[1] + epi_tile_layout = cute.make_layout( + epi_tile_shape, stride=(1, epi_tile_shape[0]) + ) + + # Initialize tma store pipeline + tma_store_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_mma_warps * self.num_threads_per_warp, + ) + tma_store_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=tma_store_producer_group, + ) + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # Copy from accumulators to D registers + for epi_v in cutlass.range_constexpr(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + # Type conversion + tRS_rD_out = cute.make_rmem_tensor( + tRS_rD_layout.shape, self.c_dtype + ) + acc_vec = tRS_rD.load() + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + + # Register to shared memory + epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3]) + cute.copy( + tiled_copy_r2s, + tRS_rD_out, + tRS_sD[(None, None, None, epi_buffer)], + ) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # barrier for sync + self.epilog_sync_barrier.arrive_and_wait() + + # Get the global memory coordinate for the current epi tile. + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from shared memory to global memory + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sD[(None, epi_buffer)], + bSG_gD[(None, gmem_coord)], + ) + tma_store_pipeline.producer_commit() + tma_store_pipeline.producer_acquire() + + # Advance to the next work tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + tma_store_pipeline.producer_tail() + # End of for k_tile loop + # End of while loop + # End of MMA warp group + # Start of DMA warp group + elif warp_idx == self.num_mma_warps: + cute.arch.warpgroup_reg_dealloc(self.load_register_requirement) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + tAgA_mkl = tAgA[(None, tile_coord_mnl[0], None, tile_coord_mnl[2])] + tBgB_nkl = tBgB[(None, tile_coord_mnl[1], None, tile_coord_mnl[2])] + + mainloop_producer_state.reset_count() + + for k_tile in range(0, k_tile_cnt, 1, unroll=1): + # ///////////////////////////////////////////////////////////////////////////// + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.producer_acquire(mainloop_producer_state) + + # ///////////////////////////////////////////////////////////////////////////// + # Slice to global/shared memref to current k_tile + # ///////////////////////////////////////////////////////////////////////////// + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # ///////////////////////////////////////////////////////////////////////////// + # TMA load A/B + # ///////////////////////////////////////////////////////////////////////////// + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # end of while loop + + # Wait A/B buffer empty + mainloop_pipeline.producer_tail(mainloop_producer_state) + return + + @staticmethod + def _compute_stages( + tile_shape_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + epi_tile: tuple[int, int], + c_dtype: type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (A/B operand stages, epilogue stages) + :rtype: tuple[int, int] + """ + epi_stage = 8 + c_bytes_per_stage = cute.size(epi_tile) * c_dtype.width // 8 + epi_bytes = c_bytes_per_stage * epi_stage + + a_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + ab_bytes_per_stage = ( + cute.size(a_shape) * a_dtype.width // 8 + + cute.size(b_shape) * b_dtype.width // 8 + ) + mbar_helpers_bytes = 1024 + + ab_stage = ( + (smem_capacity - occupancy * 1024) // occupancy + - mbar_helpers_bytes + - epi_bytes + ) // ab_bytes_per_stage + return ab_stage, epi_stage + + @staticmethod + def _make_smem_layouts( + tile_shape_mnk: tuple[int, int, int], + epi_tile: tuple[int, int], + a_dtype: type[cutlass.Numeric], + a_layout: cute.Layout, + b_dtype: type[cutlass.Numeric], + b_layout: cute.Layout, + ab_stage: int, + c_dtype: type[cutlass.Numeric], + c_layout: cute.Layout, + epi_stage: int, + ) -> tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]: + """Create shared memory layouts for A, B, and C tensors. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + :param a_dtype: Data type for matrix A + :type a_dtype: type[cutlass.Numeric] + :param a_layout: Layout for matrix A + :type a_layout: Layout + :param b_dtype: Data type for matrix B + :type b_dtype: type[cutlass.Numeric] + :param b_layout: Layout for matrix B + :type b_layout: Layout + :param ab_stage: Number of stages for A/B tensors + :type ab_stage: int + :param c_dtype: Data type for output matrix C + :type c_dtype: type[cutlass.Numeric] + :param c_layout: leading dimension of the output matrix C + :type c_layout: Layout + :param epi_stage: Number of epilogue stages + :type epi_stage: int + + :return: Tuple of shared memory layouts for A, B, and C + :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] + """ + a_smem_layout_staged = sm90_utils.make_smem_layout_a( + a_layout, + tile_shape_mnk, + a_dtype, + ab_stage, + ) + + b_smem_layout_staged = sm90_utils.make_smem_layout_b( + b_layout, + tile_shape_mnk, + b_dtype, + ab_stage, + ) + + epi_smem_layout_staged = sm90_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + epi_stage, + ) + + return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged + + @staticmethod + def _compute_grid( + c: cute.Tensor, + tile_shape_mnk: tuple[int, int, int], + max_active_clusters: cutlass.Constexpr, + ) -> tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + c_shape = cute.slice_(tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (1, 1, 1) + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return tile_sched_params, grid + + @staticmethod + def _make_tma_store_atoms_and_tensors( + tensor_c: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: tuple[int, int], + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for C tensor storage. + + :param tensor_c: Output tensor C + :type tensor_c: cute.Tensor + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + + :return: TMA atom and tensor for C + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom( + cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), + tensor_c, + epi_smem_layout, + epi_tile, + ) + + return tma_atom_c, tma_tensor_c + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: tuple[int, int], + mcast_dim: int, + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout_staged: Shared memory layout for the tensor + :type smem_layout_staged: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = ( + cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + +def run( + mnkl: Tuple[int, int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + tile_shape_mnk: Tuple[int, int, int], + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + print("Running Blackwell Geforce Dense GEMM with:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Tile Shape: {tile_shape_mnk}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + + a_dtype = getattr(cutlass, a_dtype) if isinstance(a_dtype, str) else a_dtype + b_dtype = getattr(cutlass, b_dtype) if isinstance(b_dtype, str) else b_dtype + c_dtype = getattr(cutlass, c_dtype) if isinstance(c_dtype, str) else c_dtype + acc_dtype = getattr(cutlass, acc_dtype) if isinstance(acc_dtype, str) else acc_dtype + + # Unpack parameters + m, n, k, l = mnkl + cluster_shape_mnk = (1, 1, 1) + + # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major, a_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major, b_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major, c_dtype) + + def create_cute_tensor(data_ref, cutlass_dtype): + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + data_ref, cutlass_dtype, True, 16 + ) + + if cutlass_dtype.is_float and cutlass_dtype.width == 8: + f32_torch_tensor = data_ref.to(dtype=torch.float32) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + cutlass_dtype, + is_dynamic_layout=True, + ) + return cute_tensor, torch_tensor + + a_tensor, a_torch_gpu = create_cute_tensor(a_torch_cpu, a_dtype) + b_tensor, b_torch_gpu = create_cute_tensor(b_torch_cpu, b_dtype) + c_tensor, c_torch_gpu = create_cute_tensor(c_torch_cpu, c_dtype) + + gemm = Sm120GemmKernel( + acc_dtype, + tile_shape_mnk, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ) + + # Initialize stream + stream = cutlass_torch.default_stream() + # compile gemm kernel + compiled_gemm = cute.compile( + gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, stream + ) + + if not skip_ref_check: + print("Reference checking ...") + # execution + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + torch.cuda.synchronize() + + # Ref check + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch_cpu.to(dtype=torch.float32), + b_torch_cpu.to(dtype=torch.float32), + ) + + # Copy gpu tensor to cpu + kernel_result = c_torch_gpu.cpu() + + # Convert ref to c_dtype + _, ref_torch_gpu = create_cute_tensor(ref, c_dtype) + ref_result = ref_torch_gpu.cpu() + + # Assert close results + torch.testing.assert_close( + kernel_result, ref_result, atol=tolerance, rtol=1e-03 + ) + + def generate_tensors(): + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major, a_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major, b_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major, c_dtype) + mA_workspace, _ = create_cute_tensor(a_torch_cpu, a_dtype) + mB_workspace, _ = create_cute_tensor(b_torch_cpu, b_dtype) + mC_workspace, _ = create_cute_tensor(c_torch_cpu, c_dtype) + return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_gpu.numel() * a_torch_gpu.element_size() + + b_torch_gpu.numel() * b_torch_gpu.element_size() + + c_torch_gpu.numel() * c_torch_gpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + print(f"Execution time: {exec_time} microseconds per iteration") + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + args = parse_arguments() + run( + args.mnkl, + args.a_dtype, + args.b_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.tile_shape_mnk, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt b/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt index f2cf2336..1870197f 100644 --- a/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt +++ b/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt @@ -30,11 +30,12 @@ cmake_minimum_required(VERSION 3.15) project(tensor) # Find Python +find_package(Python COMPONENTS Interpreter Development REQUIRED) find_package(Python3 COMPONENTS Interpreter Development REQUIRED) # Get Python site-packages directory using Python execute_process( - COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])" + COMMAND ${Python3_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])" OUTPUT_VARIABLE Python_SITE_PACKAGES OUTPUT_STRIP_TRAILING_WHITESPACE ) @@ -45,7 +46,13 @@ message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}") list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake) # Find nanobind -find_package(nanobind REQUIRED) +find_package(nanobind) +if(NOT nanobind_FOUND) + message(FATAL_ERROR + "nanobind not found!\n" + "Please install nanobind with: pip install nanobind\n" + ) +endif() # Add the module nanobind_add_module(tensor tensor.cpp) diff --git a/examples/python/CuTeDSL/cute/ffi/jit_argument.py b/examples/python/CuTeDSL/cute/ffi/jit_argument.py index acdb42ef..bf21bc17 100644 --- a/examples/python/CuTeDSL/cute/ffi/jit_argument.py +++ b/examples/python/CuTeDSL/cute/ffi/jit_argument.py @@ -54,7 +54,6 @@ import cutlass.cute as cute from cutlass._mlir import ir from cutlass._mlir.dialects import llvm -import cutlass._mlir.extras.types as T class ExampleTensorValue(ir.Value): @@ -244,7 +243,7 @@ import tempfile import torch -def run_test(tmpdir=None): +def run_test(tmpdir=None, cmake_args=""): # Skip cleanup if user provides tmpdir cleanup = tmpdir is None # Initialize temporary build directory @@ -253,7 +252,8 @@ def run_test(tmpdir=None): try: current_dir = os.path.dirname(os.path.abspath(__file__)) - subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True) + cmake_args = cmake_args.split() + subprocess.run(["cmake", "-B", tmpdir, current_dir] + cmake_args, check=True) subprocess.run(["cmake", "--build", tmpdir], check=True) sys.path.append(tmpdir) @@ -284,7 +284,10 @@ def run_test(tmpdir=None): # Execute compiled function compiled_func(tensor) except Exception as e: - print(e) + import traceback + + traceback.print_exception(type(e), e, e.__traceback__) + raise e finally: if cleanup: # Clean up the temporary directory @@ -298,8 +301,17 @@ if __name__ == "__main__": description="Set temporary directory for building C modules" ) parser.add_argument( - "--tmp-dir", type=str, help="Temporary directory path for building C modules" + "--tmp-dir", + type=str, + default=None, + help="Temporary directory path for building C modules", + ) + parser.add_argument( + "--cmake-args", + type=str, + default="", + help="Extra CMake arguments for building C modules", ) args = parser.parse_args() - run_test(args.tmp_dir) + run_test(tmpdir=args.tmp_dir, cmake_args=args.cmake_args) diff --git a/examples/python/CuTeDSL/cute/torch_fake_tensor.py b/examples/python/CuTeDSL/cute/torch_fake_tensor.py new file mode 100644 index 00000000..06c51ea8 --- /dev/null +++ b/examples/python/CuTeDSL/cute/torch_fake_tensor.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import torch + +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + + +"""Example demonstrating how to use CuTe with PyTorch's FakeTensor mode. + +This example shows how to: +1. Use PyTorch's FakeTensor mode to compile a CuTe function without real data +2. Execute the compiled function on real data later + +FakeTensor mode allows compiling code without allocating real memory, which is useful +for ahead-of-time compilation scenarios. The compiled function can then be executed +on real tensors that match the expected shapes and dtypes. + +Primary goals of this example are to demonstrate: How to use PyTorch's FakeTensor mode with CuTe +to enable ahead-of-time compilation without real data allocation. + +The example: +1. Creates a fake tensor in PyTorch using FakeTensor mode +2. Compiles a CuTe function using the fake tensor without allocating real memory +3. Creates a real tensor with matching shape and dtype +4. Executes the compiled function on the real tensor + +To run this example: + +.. code-block:: bash + + python examples/cute/torch_fake_tensor.py +""" + + +@cute.jit +def print_tensor(t: cute.Tensor): + cute.print_tensor(t) + + +if __name__ == "__main__": + from torch._subclasses.fake_tensor import FakeTensorMode + + shape = (3, 4) + with FakeTensorMode(): + fake_tensor = torch.zeros(shape, dtype=torch.float32) + compiled_fn = cute.compile(print_tensor, from_dlpack(fake_tensor)) + + real_tensor = torch.randn(shape, dtype=torch.float32) + compiled_fn(from_dlpack(real_tensor)) diff --git a/examples/python/CuTeDSL/hopper/dense_gemm.py b/examples/python/CuTeDSL/hopper/dense_gemm.py index c59ace02..ce7afa45 100644 --- a/examples/python/CuTeDSL/hopper/dense_gemm.py +++ b/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -91,10 +91,11 @@ To collect performance with NCU profiler: --a_major k --b_major k --c_major n Constraints: -* Supported input data types: fp16, fp8 (e4m3fn, e5m2) +* Supported input data types: fp16, fp8 (e4m3fn, e5m2), int8, uint8 * For fp16 types, A and B must have the same data type -* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit -* Fp8 types only support k-major layout +* For fp8 types, A and B can have different types (e4m3fn or e5m2) +* For 8-bit integer types, A and B can have different types (int8 or uint8) +* 8-bit types (e4m3fn, e5m2, int8, uint8) only support k-major layout * CTA tile shape M must be 64/128 * CTA tile shape N must be 64/128/256 * Cluster shape M/N must be positive and power of 2, total cluster size <= 4 @@ -212,17 +213,19 @@ class HopperWgmmaGemmKernel: :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing :type cluster_shape_mn: Tuple[int, int] - :note: Data type requirements: - - For 16-bit types: A and B must have the same data type - - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit - - Float8 types only support k-major layout - - :note: Supported data types: + :note: Supported A/B data types: - Float16 + A and B must have the same data type - Float8E4M3FN/Float8E5M2 + A and B can have different types (Float8E4M3FN/Float8E5M2) + only support k-major layout + - Int8/Uint8 + A and B can have different types (Int8/Uint8) + only support k-major layout :note: Supported accumulation types: - - Float32 (for all floating point inputs) + - Float32/Float16 (for all floating point inputs) + - Int32 (for Int8/Uint8 inputs) :note: Constraints: - CTA tile M must be 64/128 @@ -339,7 +342,7 @@ class HopperWgmmaGemmKernel: self.is_b_mcast = self.num_mcast_ctas_b > 1 is_cooperative = self.atom_layout_mnk == (2, 1, 1) - self.epi_tile = self._sm90_compute_tile_shape_or_override( + self.epi_tile = sm90_utils.compute_tile_shape_or_override( self.tile_shape_mnk, self.c_dtype, is_cooperative=is_cooperative ) @@ -411,7 +414,7 @@ class HopperWgmmaGemmKernel: f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}" ) if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): - raise TypeError(f"a_dtype should be float16 or float8") + raise TypeError("a_dtype should be float16 or float8") self._setup_attributes() @@ -708,7 +711,7 @@ class HopperWgmmaGemmKernel: tCrB = tiled_mma.make_fragment_B(tCsB) acc_shape = tCgC.shape - accumulators = cute.make_fragment(acc_shape, self.acc_dtype) + accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype) # /////////////////////////////////////////////////////////////////////////////// # Cluster wait @@ -960,7 +963,7 @@ class HopperWgmmaGemmKernel: # Allocate D registers. rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) tRS_rD_layout = cute.make_layout(rD_shape[:3]) - tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype) + tRS_rD = cute.make_rmem_tensor_like(tRS_rD_layout, self.acc_dtype) size_tRS_rD = cute.size(tRS_rD) sepi_for_tma_partition = cute.group_modes(sC, 0, 2) @@ -982,7 +985,7 @@ class HopperWgmmaGemmKernel: # Initialize tma store c_pipeline c_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + pipeline.Agent.Thread, self.threads_per_cta ) c_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.epi_stage, @@ -995,7 +998,7 @@ class HopperWgmmaGemmKernel: tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] # Type conversion - tRS_rD_out = cute.make_fragment_like(tRS_rD_layout, self.c_dtype) + tRS_rD_out = cute.make_rmem_tensor_like(tRS_rD_layout, self.c_dtype) acc_vec = tRS_rD.load() tRS_rD_out.store(acc_vec.to(self.c_dtype)) @@ -1010,7 +1013,7 @@ class HopperWgmmaGemmKernel: space=cute.arch.SharedSpace.shared_cta, ) # barrier for sync - cute.arch.barrier() + pipeline.sync(barrier_id=1) gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) # Copy from shared memory to global memory @@ -1023,7 +1026,7 @@ class HopperWgmmaGemmKernel: c_pipeline.producer_commit() c_pipeline.producer_acquire() - cute.arch.barrier() + pipeline.sync(barrier_id=1) if warp_idx == 0: c_pipeline.producer_tail() @@ -1073,39 +1076,6 @@ class HopperWgmmaGemmKernel: ) // ab_bytes_per_stage return ab_stage, epi_stage - @staticmethod - def _sm90_compute_tile_shape_or_override( - tile_shape_mnk: tuple[int, int, int], - element_type: type[cutlass.Numeric], - is_cooperative: bool = False, - epi_tile_override: tuple[int, int] | None = None, - ) -> tuple[int, int]: - """Compute the epilogue tile shape or use override if provided. - - :param tile_shape_mnk: CTA tile shape (M,N,K) - :type tile_shape_mnk: Tuple[int, int, int] - :param element_type: Data type of elements - :type element_type: type[cutlass.Numeric] - :param is_cooperative: Whether to use cooperative approach - :type is_cooperative: bool - :param epi_tile_override: Optional override for epilogue tile shape - :type epi_tile_override: Tuple[int, int] or None - - :return: Computed epilogue tile shape - :rtype: Tuple[int, int] - """ - if epi_tile_override is not None: - return epi_tile_override - if is_cooperative: - tile_m = min(128, cute.size(tile_shape_mnk, mode=[0])) - tile_n = min(32, cute.size(tile_shape_mnk, mode=[1])) - return (tile_m, tile_n) - else: - n_perf = 64 if element_type.width == 8 else 32 - tile_m = min(64, cute.size(tile_shape_mnk, mode=[0])) - tile_n = min(n_perf, cute.size(tile_shape_mnk, mode=[1])) - return (tile_m, tile_n) - @staticmethod def _make_smem_layouts( tile_shape_mnk: tuple[int, int, int], @@ -1145,60 +1115,25 @@ class HopperWgmmaGemmKernel: :return: Tuple of shared memory layouts for A, B, and C :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] """ - a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) - - a_is_k_major = ( - a_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K - ) - b_is_k_major = ( - b_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K - ) - a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0] - a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - a_layout, - a_dtype, - a_major_mode_size, - ), + a_smem_layout_staged = sm90_utils.make_smem_layout_a( + a_layout, + tile_shape_mnk, a_dtype, - ) - a_smem_layout_staged = cute.tile_to_shape( - a_smem_layout_atom, - cute.append(a_smem_shape, ab_stage), - order=(0, 1, 2) if a_is_k_major else (1, 0, 2), + ab_stage, ) - b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None)) - - b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1] - b_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - b_layout, - b_dtype, - b_major_mode_size, - ), + b_smem_layout_staged = sm90_utils.make_smem_layout_b( + b_layout, + tile_shape_mnk, b_dtype, - ) - b_smem_layout_staged = cute.tile_to_shape( - b_smem_layout_atom, - cute.append(b_smem_shape, ab_stage), - order=(0, 1, 2) if b_is_k_major else (1, 0, 2), + ab_stage, ) - c_smem_shape = epi_tile - c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0] - c_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - c_layout, - c_dtype, - c_major_mode_size, - ), + epi_smem_layout_staged = sm90_utils.make_smem_layout_epi( c_dtype, - ) - epi_smem_layout_staged = cute.tile_to_shape( - c_smem_layout_atom, - cute.append(c_smem_shape, epi_stage), - order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2), + c_layout, + epi_tile, + epi_stage, ) return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged @@ -1248,14 +1183,11 @@ class HopperWgmmaGemmKernel: :rtype: Tuple[cute.CopyAtom, cute.Tensor] """ epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) - c_cta_v_layout = cute.composition( - cute.make_identity_layout(tensor_c.shape), epi_tile - ) tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom( cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), tensor_c, epi_smem_layout, - c_cta_v_layout, + epi_tile, ) return tma_atom_c, tma_tensor_c @@ -1326,44 +1258,130 @@ class HopperWgmmaGemmKernel: :rtype: bool """ is_valid = True - # tested a_dtype - if a_dtype not in { + + valid_ab_dtypes = { cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2, - }: + cutlass.Uint8, + cutlass.Int8, + } + if a_dtype not in valid_ab_dtypes: is_valid = False - # tested b_dtype - if b_dtype not in { - cutlass.Float16, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: - is_valid = False - # tested acc_dtype - if acc_dtype not in {cutlass.Float32, cutlass.Float16}: - is_valid = False - # tested c_dtype - if c_dtype not in { - cutlass.Float32, - cutlass.Float16, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: + if b_dtype not in valid_ab_dtypes: is_valid = False + # make sure a_dtype == b_dtype for Float16 if a_dtype.width == 16 and a_dtype != b_dtype: is_valid = False - # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2) if a_dtype.width != b_dtype.width: is_valid = False + if not a_dtype.is_same_kind(b_dtype): + is_valid = False - # for Float8 types, this implementation only supports k-major layout + # for 8-bit types, this implementation only supports k-major layout if (a_dtype.width == 8 and a_major != "k") or ( b_dtype.width == 8 and b_major != "k" ): is_valid = False + # Define compatibility mapping between accumulator type and AB type + acc_ab_compatibility = { + cutlass.Float32: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Float16: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, + } + # Check compatibility between accumulator type and A type + if a_dtype not in acc_ab_compatibility[acc_dtype]: + is_valid = False + + # Define compatibility mapping between accumulator type and C type + acc_c_compatibility = { + cutlass.Float32: { + cutlass.Float32, + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Float16: { + cutlass.Float32, + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: { + cutlass.Float32, + cutlass.Float16, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + }, + } + # Check compatibility between accumulator type and C type + if c_dtype not in acc_c_compatibility[acc_dtype]: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False return is_valid @@ -1418,7 +1436,7 @@ def run( :rtype: float """ - print(f"Running Hopper Dense GEMM with:") + print("Running Hopper Dense GEMM with:") print(f"mnkl: {mnkl}") print( f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" @@ -1434,15 +1452,19 @@ def run( # Unpack parameters m, n, k, l = mnkl - # Skip unsupported types if not HopperWgmmaGemmKernel.is_valid_dtypes( a_dtype, b_dtype, acc_dtype, c_dtype, a_major, b_major ): raise TypeError( - f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {c_dtype}, {a_major=}, {b_major=}" + f"unsupported combination of types and majors: A {a_dtype}, B {b_dtype}, Acc {acc_dtype}, C {c_dtype}, {a_major=}, {b_major=}" + ) + if not HopperWgmmaGemmKernel.is_valid_tensor_alignment( + m, n, k, l, a_dtype, c_dtype, a_major, b_major, c_major + ): + raise TypeError( + "the contiguous dimension of A/B/C tensors is not 16 bytes aligned" ) - # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero) if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") diff --git a/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py b/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py new file mode 100644 index 00000000..cd6e2d6a --- /dev/null +++ b/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py @@ -0,0 +1,1619 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Tuple, Type +import math +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.utils.hopper_helpers as sm90_utils + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture +using CuTe DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with MMA between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and MMA + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: + - Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction. + - Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations. + +Hopper WGMMA instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/hopper/dense_gemm_persistent.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape +is (1,1). The input, mma accumulator and output data type are set as fp16, fp32 +and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/hopper/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +Constraints are same as dense_gemm.py: +* Supported input data types: fp16, fp8 (e4m3fn, e5m2), int8, uint8 +* For fp16 types, A and B must have the same data type +* For fp8 types, A and B can have different types (e4m3fn or e5m2) +* For 8-bit integer types, A and B can have different types (int8 or uint8) +* 8-bit types (e4m3fn, e5m2, int8, uint8) only support k-major layout +* CTA tile shape M must be 64/128 +* CTA tile shape N must be 64/128/256 +* CTA tile shape K must be 64 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 4 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively. +""" + + +# Helpers to parse args +def parse_comma_separated_ints(s: str): + try: + return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.") + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(4096, 4096, 4096, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--tile_shape_mn", + type=parse_comma_separated_ints, + choices=[(128, 128), (128, 256), (128, 64), (64, 64)], + default=(128, 128), + help="Cta tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + choices=[(1, 1), (2, 1), (1, 2), (2, 2)], + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--swizzle_size", + type=int, + default=1, + help="Swizzling size in the unit of cluster for improving L2 cache hit rate", + ) + parser.add_argument( + "--raster_order", + type=str, + choices=["along_m", "along_n"], + default="along_m", + help="Rasterization order of clusters", + ) + parser.add_argument( + "--a_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--b_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + if len(args.tile_shape_mn) != 2: + parser.error("--tile_shape_mn must contain exactly 2 values") + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + return args + + +class HopperWgmmaGemmPersistentKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Hopper GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float16 + A and B must have the same data type + - Float8E4M3FN/Float8E5M2 + A and B can have different types (Float8E4M3FN/Float8E5M2) + only support k-major layout + - Int8/Uint8 + A and B can have different types (Int8/Uint8) + only support k-major layout + + :note: Supported accumulation types: + - Float32/Float16 (for all floating point inputs) + - Int32 (for Int8/Uint8 inputs) + + :note: Constraints: + - CTA tile M must be 64/128 + - CTA tile N must be 64/128/256 + - CTA tile K must be 64 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 4 + + Example: + >>> gemm = HopperWgmmaGemmPersistentKernel( + ... acc_dtype=cutlass.Float32, + ... tile_shape_mn=(128, 256), + ... cluster_shape_mn=(1, 1) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + tile_shape_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + swizzle_size: int, + raster_along_m: bool, + ): + """ + Initializes the configuration for a Hopper dense GEMM kernel. + + This configuration includes data types for operands, tile shape, cluster configuration, + and thread layout. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = acc_dtype + + self.cluster_shape_mn = cluster_shape_mn + self.swizzle_size = swizzle_size + self.raster_along_m = raster_along_m + self.mma_inst_shape_mn = None + # K dimension is deferred in _setup_attributes + self.tile_shape_mnk = (*tile_shape_mn, 1) + # For large tile size, using two warp groups is preferred because using only one warp + # group may result in register spill + self.atom_layout_mnk = ( + (2, 1, 1) + if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128 + else (1, 1, 1) + ) + self.num_mcast_ctas_a = None + self.num_mcast_ctas_b = None + self.is_a_mcast = False + self.is_b_mcast = False + self.tiled_mma = None + + self.occupancy = 1 + self.num_dma_warp_groups = 1 + self.num_mma_warp_groups = math.prod(self.atom_layout_mnk) + self.num_warps_per_warp_group = 4 + self.num_threads_per_warp_group = self.num_warps_per_warp_group * 32 + self.threads_per_cta = ( + self.num_dma_warp_groups + self.num_mma_warp_groups + ) * self.num_threads_per_warp_group + self.load_warp_id = 0 + self.epi_store_warp_id = ( + self.num_dma_warp_groups * self.num_warps_per_warp_group + ) + self.load_register_requirement = 40 + self.mma_register_requirement = 232 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90") + + self.ab_stage = None + self.epi_stage = None + + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + + self.shared_storage = None + self.buffer_align_bytes = 1024 + + self.num_mma_threads = ( + self.num_mma_warp_groups * self.num_threads_per_warp_group + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, num_threads=self.num_mma_threads + ) + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + """ + + # check the cta tile shape + if self.tile_shape_mnk[0] not in [64, 128]: + raise ValueError("CTA tile shape M must be 64/128") + if self.tile_shape_mnk[1] not in [64, 128, 256]: + raise ValueError("CTA tile shape N must be 64/128/256") + + self.tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.tile_shape_mnk[1]), + ) + mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.tile_shape_mnk = ( + self.tile_shape_mnk[0], + self.tile_shape_mnk[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1)) + self.num_mcast_ctas_a = self.cluster_shape_mn[1] + self.num_mcast_ctas_b = self.cluster_shape_mn[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + is_cooperative = self.atom_layout_mnk == (2, 1, 1) + self.epi_tile = self._sm90_compute_tile_shape_or_override( + self.tile_shape_mnk, self.c_dtype, is_cooperative=is_cooperative + ) + + # Compute stage before compute smem layout + self.ab_stage, self.epi_stage = self._compute_stages( + self.tile_shape_mnk, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.smem_capacity, + self.occupancy, + ) + + ( + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ) = self._make_smem_layouts( + self.tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.a_layout, + self.b_dtype, + self.b_layout, + self.ab_stage, + self.c_dtype, + self.c_layout, + self.epi_stage, + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + """ + + # setup static attributes before smem/grid/tma computation + self.a_dtype = a.element_type + self.b_dtype = b.element_type + self.c_dtype = c.element_type + self.a_layout = utils.LayoutEnum.from_tensor(a) + self.b_layout = utils.LayoutEnum.from_tensor(b) + self.c_layout = utils.LayoutEnum.from_tensor(c) + + if cutlass.const_expr( + self.a_dtype.width == 16 and self.a_dtype != self.b_dtype + ): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width): + raise TypeError( + f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}" + ) + if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): + raise TypeError("a_dtype should be float16, float8, or int8 ") + + self._setup_attributes() + + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + a, + self.a_smem_layout_staged, + (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), + self.cluster_shape_mn[1], + ) + + tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( + b, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + self.cluster_shape_mn[0], + ) + + tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors( + c, + self.epi_smem_layout_staged, + self.epi_tile, + ) + + tile_sched_params, grid = self._compute_grid( + c, + self.tile_shape_mnk, + self.cluster_shape_mn, + self.swizzle_size, + self.raster_along_m, + max_active_clusters, + ) + + @cute.struct + class SharedStorage: + mainloop_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.epi_smem_layout_staged), + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.tiled_mma, + self.cta_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + min_blocks_per_mp=1, + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tiled_mma: cute.TiledMma, + cta_layout_mnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: utils.PersistentTileSchedulerParams, + ): + """ + GPU device kernel performing the batched GEMM computation. + + :param tma_atom_a: TMA copy atom for A tensor + :type tma_atom_a: cute.CopyAtom + :param mA_mkl: Input tensor A + :type mA_mkl: cute.Tensor + :param tma_atom_b: TMA copy atom for B tensor + :type tma_atom_b: cute.CopyAtom + :param mB_nkl: Input tensor B + :type mB_nkl: cute.Tensor + :param tma_atom_c: TMA copy atom for C tensor + :type tma_atom_c: cute.CopyAtom + :param mC_mnl: Output tensor C + :type mC_mnl: cute.Tensor + :param tiled_mma: Tiled MMA object + :type tiled_mma: cute.TiledMma + :param cta_layout_mnk: CTA layout + :type cta_layout_mnk: cute.Layout + :param a_smem_layout_staged: Shared memory layout for A + :type a_smem_layout_staged: cute.ComposedLayout + :param b_smem_layout_staged: Shared memory layout for B + :type b_smem_layout_staged: cute.ComposedLayout + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param tile_sched_params: Parameters for the persistent tile scheduler + :type tile_sched_params: utils.PersistentTileSchedulerParams + """ + + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Prefetch Tma desc + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + a_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=0 + ) + + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + tma_copy_bytes = cute.size_in_bytes( + self.a_dtype, a_smem_layout + ) + cute.size_in_bytes(self.b_dtype, b_smem_layout) + + # Alloc and init AB full/empty + ACC full mbar (pipeline) + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # mbar arrays + mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() + + # Threads/warps participating in this pipeline + mainloop_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread + ) + # Each warp will constribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + consumer_arrive_cnt = ( + mcast_size * self.num_mma_warp_groups * self.num_warps_per_warp_group + ) + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=mainloop_pipeline_array_ptr, + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + cta_layout_vmnk=cute.make_layout((1, *cta_layout_mnk.shape)), + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # Generate smem tensor A/B + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + sC = storage.sC.get_tensor( + epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner + ) + + # Local_tile partition global tensors + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, + cute.slice_(self.tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, + cute.slice_(self.tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, + cute.slice_(self.tile_shape_mnk, (None, None, 0)), + (None, None, None), + ) + + # Partition shared tensor for TMA load A/B + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + a_cta_crd, + a_cta_layout, + cute.group_modes(sA, 0, 2), + cute.group_modes(gA_mkl, 0, 2), + ) + + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + tBsB, tBgB = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, + b_cta_crd, + b_cta_layout, + cute.group_modes(sB, 0, 2), + cute.group_modes(gB_nkl, 0, 2), + ) + + # Partition global tensor for TiledMMA_A/B/C + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) + mma_warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma = tiled_mma.get_slice( + mma_warp_group_thread_layout(warp_group_idx - self.num_dma_warp_groups) + ) + + # Make fragments + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = tiled_mma.make_fragment_A(tCsA) + tCrB = tiled_mma.make_fragment_B(tCsB) + + tCgC = thr_mma.partition_C(gC_mnl) + acc_shape = tCgC.shape[:3] + accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype) + + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # Cluster wait for barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.sync_threads() + + is_dma_warp_group = warp_group_idx < self.num_dma_warp_groups + if is_dma_warp_group: + cute.arch.warpgroup_reg_dealloc(self.load_register_requirement) + + if warp_idx == self.load_warp_id: + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + mainloop_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + tAgA_mkl = tAgA[(None, tile_coord_mnl[0], None, tile_coord_mnl[2])] + tBgB_nkl = tBgB[(None, tile_coord_mnl[1], None, tile_coord_mnl[2])] + + mainloop_producer_state.reset_count() + + for k_tile in range(k_tile_cnt): + # Conditionally wait for AB buffer empty + mainloop_pipeline.producer_acquire(mainloop_producer_state) + # Slice to global/shared memref to current k_tile + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mainloop_pipeline.producer_tail(mainloop_producer_state) + + # MMA warp group + if not is_dma_warp_group: + cute.arch.warpgroup_reg_alloc(self.mma_register_requirement) + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + mainloop_consumer_read_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + mainloop_consumer_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + + num_k_blocks = cute.size(tCrA, mode=[2]) + + # Partition for epilogue + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.c_layout, + elem_ty_d=self.c_dtype, + elem_ty_acc=self.acc_dtype, + ) + + copy_atom_C = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + self.c_layout.is_m_major_c(), + 4, + ), + self.c_dtype, + ) + + tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + + tiled_copy_r2s = cute.make_tiled_copy_S( + copy_atom_r2s, + tiled_copy_C_Atom, + ) + + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice( + tidx - self.num_dma_warp_groups * self.num_threads_per_warp_group + ) + # (t)hread-partition for (r)egister to (s)mem copy (tRS_) + tRS_sD = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rAcc = tiled_copy_r2s.retile(accumulators) + + # Allocate D registers. + rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + tRS_rD = cute.make_rmem_tensor(tRS_rD_layout.shape, self.acc_dtype) + tRS_rD_out = cute.make_rmem_tensor(tRS_rD_layout.shape, self.c_dtype) + size_tRS_rD = cute.size(tRS_rD) + + k_pipe_mmas = 1 + prologue_mma_cnt = min(k_pipe_mmas, k_tile_cnt) + + # Initialize tma store pipeline + tma_store_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_mma_threads, + ) + tma_store_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=tma_store_producer_group, + ) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + gC_mnl_slice = gC_mnl[(None, None, *tile_coord_mnl)] + + # MAINLOOP + mainloop_consumer_read_state.reset_count() + mainloop_consumer_release_state.reset_count() + accumulators.fill(0.0) + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + cute.nvgpu.warpgroup.fence() + + for k_tile in range(prologue_mma_cnt): + # Wait for TMA copies to complete + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) + # WGMMA + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + cute.gemm( + tiled_mma, + accumulators, + tCrA[k_block_coord], + tCrB[k_block_coord], + accumulators, + ) + + cute.nvgpu.warpgroup.commit_group() + mainloop_consumer_read_state.advance() + + for k_tile in range(prologue_mma_cnt, k_tile_cnt): + # Wait for TMA copies to complete + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) + # WGMMA + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + cute.gemm( + tiled_mma, + accumulators, + tCrA[k_block_coord], + tCrB[k_block_coord], + accumulators, + ) + + cute.nvgpu.warpgroup.commit_group() + # Wait on the wgmma barrier for WGMMA to complete + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_release_state.advance() + mainloop_consumer_read_state.advance() + + cute.nvgpu.warpgroup.wait_group(0) + for k_tile in range(prologue_mma_cnt): + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_release_state.advance() + + # Epilogue + tCgC_for_tma_partition = cute.zipped_divide(gC_mnl_slice, self.epi_tile) + + # thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + tCgC_for_tma_partition, + ) + + epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1]) + epi_tile_shape = tCgC_for_tma_partition.shape[1] + epi_tile_layout = cute.make_layout( + epi_tile_shape, stride=(epi_tile_shape[1], 1) + ) + + num_prev_epi_tiles = tile_sched.num_tiles_executed * epi_tile_num + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # Copy from accumulators to D registers + for epi_v in cutlass.range_constexpr(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + # Type conversion + acc_vec = tRS_rD.load() + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + + # Copy from D registers to shared memory + epi_buffer = (num_prev_epi_tiles + epi_idx) % cute.size( + tRS_sD, mode=[3] + ) + cute.copy( + tiled_copy_r2s, + tRS_rD_out, + tRS_sD[(None, None, None, epi_buffer)], + ) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from shared memory to global memory + if warp_idx == self.epi_store_warp_id: + cute.copy( + tma_atom_c, + bSG_sD[(None, epi_buffer)], + bSG_gD[(None, gmem_coord)], + ) + tma_store_pipeline.producer_commit() + tma_store_pipeline.producer_acquire() + + self.epilog_sync_barrier.arrive_and_wait() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + tma_store_pipeline.producer_tail() + + @staticmethod + def _compute_stages( + tile_shape_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + epi_tile: tuple[int, int], + c_dtype: type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + :param c_dtype: The data type of the output tensor + :type c_dtype: type[cutlass.Numeric] + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (A/B operand stages, epilogue stages) + :rtype: tuple[int, int] + """ + + a_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + ab_bytes_per_stage = ( + cute.size(a_shape) * a_dtype.width // 8 + + cute.size(b_shape) * b_dtype.width // 8 + ) + c_bytes_per_stage = cute.size(epi_tile) * c_dtype.width // 8 + epi_stage = 4 + epi_bytes = c_bytes_per_stage * epi_stage + + mbar_helpers_bytes = 1024 + + ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes) + ) // ab_bytes_per_stage + return ab_stage, epi_stage + + @staticmethod + def _sm90_compute_tile_shape_or_override( + tile_shape_mnk: tuple[int, int, int], + element_type: type[cutlass.Numeric], + is_cooperative: bool = False, + epi_tile_override: Optional[tuple[int, int]] = None, + ) -> tuple[int, int]: + """Compute the epilogue tile shape or use override if provided. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param element_type: Data type of elements + :type element_type: type[cutlass.Numeric] + :param is_cooperative: Whether to use cooperative approach + :type is_cooperative: bool + :param epi_tile_override: Optional override for epilogue tile shape + :type epi_tile_override: Tuple[int, int] or None + + :return: Computed epilogue tile shape + :rtype: Tuple[int, int] + """ + if epi_tile_override is not None: + return epi_tile_override + if is_cooperative: + tile_m = min(128, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(32, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + else: + n_perf = 64 if element_type.width == 8 else 32 + tile_m = min(64, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(n_perf, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + + @staticmethod + def _make_smem_layouts( + tile_shape_mnk: tuple[int, int, int], + epi_tile: tuple[int, int], + a_dtype: type[cutlass.Numeric], + a_layout: utils.LayoutEnum, + b_dtype: type[cutlass.Numeric], + b_layout: utils.LayoutEnum, + ab_stage: int, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + epi_stage: int, + ) -> tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]: + """Create shared memory layouts for A, B, and C tensors. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + :param a_dtype: Data type for matrix A + :type a_dtype: type[cutlass.Numeric] + :param a_layout: Layout enum for matrix A + :type a_layout: utils.LayoutEnum + :param b_dtype: Data type for matrix B + :type b_dtype: type[cutlass.Numeric] + :param b_layout: Layout enum for matrix B + :type b_layout: utils.LayoutEnum + :param ab_stage: Number of stages for A/B tensors + :type ab_stage: int + :param c_dtype: Data type for output matrix C + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum for the output matrix C + :type c_layout: utils.LayoutEnum + :param epi_stage: Number of epilogue stages + :type epi_stage: int + + :return: Tuple of shared memory layouts for A, B, and C + :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] + """ + a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + + a_is_k_major = ( + a_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + b_is_k_major = ( + b_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0] + a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + a_layout, + a_dtype, + a_major_mode_size, + ), + a_dtype, + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, ab_stage), + order=(0, 1, 2) if a_is_k_major else (1, 0, 2), + ) + + b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + + b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1] + b_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + b_layout, + b_dtype, + b_major_mode_size, + ), + b_dtype, + ) + b_smem_layout_staged = cute.tile_to_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, ab_stage), + order=(0, 1, 2) if b_is_k_major else (1, 0, 2), + ) + + c_smem_shape = epi_tile + c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0] + c_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + c_layout, + c_dtype, + c_major_mode_size, + ), + c_dtype, + ) + epi_smem_layout_staged = cute.tile_to_shape( + c_smem_layout_atom, + cute.append(c_smem_shape, epi_stage), + order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2), + ) + + return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged + + @staticmethod + def _compute_grid( + c: cute.Tensor, + tile_shape_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + swizzle_size: int, + raster_along_m: bool, + max_active_clusters: cutlass.Constexpr, + ) -> tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + c_shape = cute.slice_(tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, + cluster_shape_mnl, + swizzle_size, + raster_along_m, + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return tile_sched_params, grid + + @staticmethod + def _make_tma_store_atoms_and_tensors( + tensor_c: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: tuple[int, int], + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for C tensor storage. + + :param tensor_c: Output tensor C + :type tensor_c: cute.Tensor + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + + :return: TMA atom and tensor for C + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom( + cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), + tensor_c, + epi_smem_layout, + epi_tile, + ) + + return tma_atom_c, tma_tensor_c + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: tuple[int, int], + mcast_dim: int, + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout_staged: Shared memory layout for the tensor + :type smem_layout_staged: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = ( + cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + @staticmethod + def is_valid_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + ) -> bool: + """ + Check if the dtypes are valid + + :param a_dtype: The data type of tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: The data type of tensor B + :type b_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: major mode of tensor A + :type a_major: str + :param b_major: major mode of tensor B + :type b_major: str + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + + valid_ab_dtypes = { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Uint8, + cutlass.Int8, + } + if a_dtype not in valid_ab_dtypes: + is_valid = False + if b_dtype not in valid_ab_dtypes: + is_valid = False + + # make sure a_dtype == b_dtype for Float16 + if a_dtype.width == 16 and a_dtype != b_dtype: + is_valid = False + if a_dtype.width != b_dtype.width: + is_valid = False + if not a_dtype.is_same_kind(b_dtype): + is_valid = False + + # for 8-bit types, this implementation only supports k-major layout + if (a_dtype.width == 8 and a_major != "k") or ( + b_dtype.width == 8 and b_major != "k" + ): + is_valid = False + + # Define compatibility mapping between accumulator type and AB type + acc_ab_compatibility = { + cutlass.Float32: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Float16: { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, + } + # Check compatibility between accumulator type and A type + if a_dtype not in acc_ab_compatibility[acc_dtype]: + is_valid = False + + # Define compatibility mapping between accumulator type and C type + acc_c_compatibility = { + cutlass.Float32: { + cutlass.Float32, + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Float16: { + cutlass.Float32, + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }, + cutlass.Int32: { + cutlass.Float32, + cutlass.Float16, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + }, + } + # Check compatibility between accumulator type and C type + if c_dtype not in acc_c_compatibility[acc_dtype]: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + +def run( + mnkl: Tuple[int, int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + tile_shape_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + swizzle_size: int = 1, + raster_along_m: bool = True, + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param a_dtype: Data type for input tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: Data type for input tensor B + :type b_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param tile_shape_mn: CTA tile shape (M, N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape (M, N) + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison + :type tolerance: float + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + + print("Running Hopper Persistent Dense GEMM with:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}") + print( + f"Swizzle size: {swizzle_size}, Raster order:", + "along_m" if raster_along_m else "along_n", + ) + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + + # Unpack parameters + m, n, k, l = mnkl + + if not HopperWgmmaGemmPersistentKernel.is_valid_dtypes( + a_dtype, b_dtype, acc_dtype, c_dtype, a_major, b_major + ): + raise TypeError( + f"unsupported combination of types and majors: A {a_dtype}, B {b_dtype}, Acc {acc_dtype}, C {c_dtype}, {a_major=}, {b_major=}" + ) + if not HopperWgmmaGemmPersistentKernel.is_valid_tensor_alignment( + m, n, k, l, a_dtype, c_dtype, a_major, b_major, c_major + ): + raise TypeError( + "the contiguous dimension of A/B/C tensors is not 16 bytes aligned" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Create and permute tensor A/B/C + a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", a_dtype) + b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", b_dtype) + c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major == "m", c_dtype) + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, a_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, b_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + gemm = HopperWgmmaGemmPersistentKernel( + acc_dtype, tile_shape_mn, cluster_shape_mn, swizzle_size, raster_along_m + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, stream + ) + + if not skip_ref_check: + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + torch.cuda.synchronize() + + # Compute reference result + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch_cpu.to(dtype=torch.float32), + b_torch_cpu.to(dtype=torch.float32), + ) + + # Convert ref to c_dtype + _, ref_torch_gpu = cutlass_torch.cute_tensor_like( + ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + ref_c = ref_torch_gpu.cpu() + + # Assert close results + torch.testing.assert_close(c_torch_gpu.cpu(), ref_c, atol=tolerance, rtol=1e-03) + + def generate_tensors(): + a_tensor_workspace, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, a_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor_workspace, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, b_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor_workspace, _ = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + return testing.JitArguments( + a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + args = parse_arguments() + run( + args.mnkl, + args.a_dtype, + args.b_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.tile_shape_mn, + args.cluster_shape_mn, + args.swizzle_size, + True if args.raster_order == "along_m" else False, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/hopper/fmha.py b/examples/python/CuTeDSL/hopper/fmha.py new file mode 100644 index 00000000..0d04831d --- /dev/null +++ b/examples/python/CuTeDSL/hopper/fmha.py @@ -0,0 +1,2537 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +A fused multi-head attention (FMHA) example for the NVIDIA Hopper SM90 architecture using CUTE DSL + +This example demonstrates an implementation of fused multi-head attention using a TMA + Hopper SM90 +TensorCore warp-specialized kernel. The implementation integrates the Q*K^T matrix multiplication, +softmax normalization, and softmax(Q*K^T)*V into a single kernel, avoiding intermediate data movement between +global memory and shared memory, thus improving computational efficiency. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA) +- 2 MMA WarpGroup for compute +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Optional causal masking for autoregressive models +- Sliding window attention masking for efficient long sequence processing + +To run this example: + +.. code-block:: bash + + python examples/hopper/fmha.py \ + --qk_acc_dtype Float32 --pv_acc_dtype Float32 \ + --mma_tiler_mn 64,128 \ + --q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \ + --is_persistent + +The above example runs FMHA with batch size 4, sequence length 1024, 8 attention heads, and head +dimension 64. The Hopper MMA tile shape is (64, 128), and the kernel uses fp16 for input/output +with fp32 for accumulation. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/hopper/fmha.py \ + --qk_acc_dtype Float32 --pv_acc_dtype Float32 \ + --mma_tiler_mn 64,128 \ + --q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \ + --is_persistent --warmup_iterations 10 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Supported head dimensions: 32, 64, 128, 256 + ** 256 for `mma_tiler_mn` shoule be (64, 256) with non-persistent mode(not present `--is_persistent` in command line +* Number of heads in Q must equal with number of heads in K +* For causal masking, use --is_causal (note: specify without =True/False) +* For persistent scheduling, use --is_persistent (note: specify without =True/False) +* For sliding window, use --window_size x,y where x is left window size and y is right window size +""" + +import argparse +import math +import os +import sys +import time +from typing import Type, Tuple, Optional + +import torch + + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute + +import cutlass.cute.testing as testing +import cutlass.cute.nvgpu.warpgroup as warpgroup +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +from cutlass._mlir.dialects import math as _math + +import cutlass.utils.hopper_helpers as sm90_utils +from cutlass.cute.runtime import from_dlpack + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(current_dir, "..")) +from utils import fmha_helpers as fmha_utils + + +class HopperFusedMultiHeadAttentionForward: + def __init__( + self, + qk_acc_dtype, + pv_acc_dtype, + mma_tiler, + is_persistent, + mask_type: fmha_utils.MaskEnum, + ): + """Initializes the configuration for a Hopper Fused Multi-Head Attention (FMHA) kernel. + + This configuration includes several key aspects: + + 1. Data Type Settings: + - qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator + - pv_acc_dtype: Data type for P*V matrix multiplication accumulator + + 2. MMA Instruction Settings: + - mma_tiler: The (M, N, K) shape of the MMA instruction unit + - qk_mma_tiler: MMA shape for Q*K^T computation + - pv_mma_tiler: MMA shape for P*V computation + + 3. Kernel Execution Mode: + - is_persistent: Boolean indicating whether to use persistent kernel mode + - mask_type: Specifies the type of mask to use (no mask, residual mask, or causal mask) + - window_size_left/right: Sliding window parameters for attention masking + + :param qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator + :type qk_acc_dtype: Type[cutlass.Numeric] + :param pv_acc_dtype: Data type for P*V matrix multiplication accumulator + :type pv_acc_dtype: Type[cutlass.Numeric] + :param mma_tiler: The (M, N, K) shape of the MMA instruction + :type mma_tiler: Tuple[int, int, int] + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param mask_type: Type of mask to use + :type mask_type: fmha_utils.MaskEnum + """ + + self.num_mma_warp_groups = 2 + self.qk_acc_dtype = qk_acc_dtype + self.pv_acc_dtype = pv_acc_dtype + self.cta_tiler = self.cta_tile_shape_mnk = ( + mma_tiler[0] * self.num_mma_warp_groups, + mma_tiler[1], + mma_tiler[2], + ) + + self.qk_mma_tiler = ( + mma_tiler[0], + mma_tiler[1], + mma_tiler[2], + ) + + self.pv_mma_tiler = ( + self.qk_mma_tiler[0], + self.qk_mma_tiler[2], + self.qk_mma_tiler[1], + ) + + self.cluster_shape_mn = (1, 1) + self.atom_layout_mnk = (1, 1, 1) + self.is_persistent = is_persistent + self.mask_type = mask_type + self.threads_per_warp = 32 + self.num_threads_per_warp_group = 128 + self.num_warps_per_warp_group = ( + self.num_threads_per_warp_group / self.threads_per_warp + ) + + # WarpGroupRole + self.load_warp_group_id = 0 + self.compute_epilogue_0_warp_group_id = 1 + self.compute_epilogue_1_warp_group_id = 2 + # ProducerWarpRole + self.producer_warp_loadkv_id = 1 + + self.num_regs_load = 40 - 2 * 8 + num_load_warp_groups = 1 + self.num_threads_per_warp_group = 128 + max_threads_per_block = ( + self.num_mma_warp_groups + num_load_warp_groups + ) * self.num_threads_per_warp_group + self.threads_per_cta = max_threads_per_block + self.num_regs_mma = 240 + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.q_stage = 2 + self.kv_stage = 5 + self.epi_stage = 2 + + @cute.jit + def __call__( + self, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + scale_softmax_log2: cutlass.Float32, + scale_softmax: cutlass.Float32, + scale_output: cutlass.Float32, + window_size_left: Optional[cutlass.Int32], + window_size_right: Optional[cutlass.Int32], + stream: cuda.CUstream, + ): + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.o_dtype = o.element_type + + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k = cute.make_tensor( + k.iterator, + cute.make_layout( + (k.shape[0], k.shape[1], ((q.shape[2], k.shape[3]), k.shape[4])), + stride=( + k.stride[0], + k.stride[1], + ((0, k.stride[3]), k.stride[4]), + ), + ), + ) + + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v = cute.make_tensor( + v.iterator, + cute.make_layout( + (v.shape[1], v.shape[0], ((q.shape[2], v.shape[3]), v.shape[4])), + stride=( + v.stride[1], + v.stride[0], + ((0, v.stride[3]), v.stride[4]), + ), + ), + ) + + # (s, d, ((h_r, h_k), b)) + q = cute.group_modes(cute.group_modes(q, begin=2, end=4), begin=2, end=4) + o = cute.group_modes(cute.group_modes(o, begin=2, end=4), begin=2, end=4) + + # (s, ((h_r, h_k), b)) + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + ( + lse.shape[0], + self.pv_mma_tiler[1], + ((lse.shape[2], lse.shape[3]), lse.shape[4]), + ), + stride=( + lse.stride[0], + 0, + ((lse.stride[2], lse.stride[3]), lse.stride[4]), + ), + ), + ) + + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + + if cutlass.const_expr(q.leading_dim != 1): # k-major + raise RuntimeError("The layout of q is not supported") + + if cutlass.const_expr(k.leading_dim != 1): # k-major + raise RuntimeError("The layout of k is not supported") + + self._setup_attributes() + + tile_shape_mnk = self.cta_tiler + self.epi_tile = sm90_utils.compute_tile_shape_or_override( + tile_shape_mnk, self.o_dtype + ) + + self.q_layout = utils.LayoutEnum.from_tensor(q) + self.k_layout = utils.LayoutEnum.from_tensor(k) + self.v_layout = utils.LayoutEnum.from_tensor(v) + self.o_layout = utils.LayoutEnum.from_tensor(o) + + self.q_major_mode = self.q_layout.sm90_mma_major_mode() + self.k_major_mode = self.k_layout.sm90_mma_major_mode() + self.v_major_mode = self.v_layout.sm90_mma_major_mode() + + p_major_mode = cute.nvgpu.warpgroup.OperandMajorMode.K + qk_tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.q_dtype, + self.k_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + self.atom_layout_mnk, + self.qk_mma_tiler[:2], + ) + + pv_tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.v_dtype, + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + self.atom_layout_mnk, + self.pv_mma_tiler[:2], + warpgroup.OperandSource.RMEM, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + q_smem_layout_staged = sm90_utils.make_smem_layout_a( + self.q_layout, + self.qk_mma_tiler, + self.q_dtype, + self.q_stage, + ) + + k_smem_layout_staged = sm90_utils.make_smem_layout_b( + self.k_layout, + self.qk_mma_tiler, + self.k_dtype, + self.kv_stage, + ) + + v_smem_layout_staged = sm90_utils.make_smem_layout_b( + self.v_layout, + self.pv_mma_tiler, + self.v_dtype, + self.kv_stage, + ) + + o_smem_layout_staged = sm90_utils.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.epi_stage, + cute.append( + cute.append(self.epi_tile, self.epi_stage), self.num_mma_warp_groups + ), + smem_order=(1, 0, 2, 3) if self.o_layout.is_m_major_c() else (0, 1, 2, 3), + ) + + # TMA load for Q + q_smem_layout = cute.slice_(q_smem_layout_staged, (None, None, 0)) + tma_atom_q, tma_tensor_q = self._make_tma_atoms_and_tensors( + q, + q_smem_layout_staged, + (self.qk_mma_tiler[0], self.qk_mma_tiler[2]), + self.cluster_shape_mnk[1], + ) + + # TMA load for K + k_smem_layout = cute.slice_(k_smem_layout_staged, (None, None, 0)) + tma_atom_k, tma_tensor_k = self._make_tma_atoms_and_tensors( + k, + k_smem_layout_staged, + (self.qk_mma_tiler[1], self.qk_mma_tiler[2]), + self.cluster_shape_mnk[0], + ) + + # TMA load for V + pv_tile_shape_mnk = ( + self.qk_mma_tiler[0], + self.qk_mma_tiler[2], + self.qk_mma_tiler[1], + ) + tma_atom_v, tma_tensor_v = self._make_tma_atoms_and_tensors( + v, + v_smem_layout_staged, + (pv_tile_shape_mnk[1], pv_tile_shape_mnk[2]), + self.cluster_shape_mnk[0], + ) + + o_cta_v_layout = cute.composition( + cute.make_identity_layout(o.shape), self.epi_tile + ) + o_smem_layout = cute.slice_(o_smem_layout_staged, (None, None, 0, 0)) + + tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_store_op, + o, + o_smem_layout, + self.epi_tile, + ) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) + k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout) + self.tma_copy_q_bytes = q_copy_size + self.tma_copy_kv_bytes = k_copy_size + + self.tile_sched_params, grid = fmha_utils.compute_grid( + o.shape, + self.cta_tiler, + self.is_persistent, + ) + + @cute.struct + class SharedStorage: + # 2 for full/empty + load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2] + MathWarpGroupOrderBarrier: cute.struct.MemRange[ + cutlass.Int64, self.num_mma_warp_groups + ] + + sO: cute.struct.Align[ + cute.struct.MemRange[ + self.o_dtype, + ( + cute.cosize(o_smem_layout_staged) + if cutlass.const_expr(self.is_persistent) + else 0 + ), + ], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(k_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q, + tma_tensor_q, + tma_atom_k, + tma_tensor_k, + tma_atom_v, + tma_tensor_v, + tma_atom_o, + tma_tensor_o, + lse, + scale_softmax_log2, + scale_softmax, + scale_output, + window_size_left, + window_size_right, + q_smem_layout_staged, + k_smem_layout_staged, + v_smem_layout_staged, + o_smem_layout_staged, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + tma_atom_o: cute.CopyAtom, + mO_qdl: cute.Tensor, + mLse_qdl: cute.Tensor, + scale_softmax_log2: cutlass.Float32, + scale_softmax: cutlass.Float32, + scale_output: cutlass.Float32, + window_size_left: Optional[cutlass.Int32], + window_size_right: Optional[cutlass.Int32], + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + o_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams, + ): + """The device kernel implementation of the Fused Multi-Head Attention for Hopper architecture. + + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp group: Loads Q, K, V data from global memory to shared memory using TMA + 2. Comput warps groups: Performs matrix multiplications (Q*K^T and P*V) using Hopper TensorCores, + then compute softmax normalization on attention scores with numerical stability. + Handle final output transformation and storage. + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking for causal or residual attention patterns. + + Key optimizations include: + - Warp group specialization for load, compute/epilogue phases + - Pipeline stages between different warps for overlapping computation and memory access + - Efficient shared memory layouts optimized for Hopper architecture + - Support for different precision data types and accumulation types + - Optional causal masking for autoregressive models + - Sliding window attention masking for efficient long sequence processing + + :param qk_tiled_mma: Tiled MMA for Q*K^T matrix multiplication + :type qk_tiled_mma: cute.TiledMma + :param pv_tiled_mma: Tiled MMA for P*V matrix multiplication + :type pv_tiled_mma: cute.TiledMma + :param tma_atom_q: TMA copy atom for query tensor loading + :type tma_atom_q: cute.CopyAtom + :param mQ_qdl: Partitioned query tensor for TMA loading + :type mQ_qdl: cute.Tensor + :param tma_atom_k: TMA copy atom for key tensor loading + :type tma_atom_k: cute.CopyAtom + :param mK_kdl: Partitioned key tensor for TMA loading + :type mK_kdl: cute.Tensor + :param tma_atom_v: TMA copy atom for value tensor loading + :type tma_atom_v: cute.CopyAtom + :param mV_dkl: Partitioned value tensor for TMA loading + :type mV_dkl: cute.Tensor + :param tma_atom_o: TMA copy atom for output tensor storage + :type tma_atom_o: cute.CopyAtom + :param mO_qdl: Partitioned output tensor for TMA storage + :type mO_qdl: cute.Tensor + :param mLse_qdl: Tensor for lse + :type mLse_qdl: cute.Tensor + :param scale_softmax_log2: The log2 scale factor for softmax computation + :type scale_softmax_log2: cutlass.Float32 + :param scale_softmax: The scale factor for softmax (currently unused) + :type scale_softmax: cutlass.Float32 + :param scale_output: The scale factor for the final output + :type scale_output: cutlass.Float32 + :param window_size_left: Left-side sliding window size for attention masking + :type window_size_left: Optional[cutlass.Int32] + :param window_size_right: Right-side sliding window size for attention masking + :type window_size_right: Optional[cutlass.Int32] + :param q_smem_layout_staged: Shared memory layout for query tensor with staging + :type q_smem_layout_staged: cute.ComposedLayout + :param k_smem_layout_staged: Shared memory layout for key tensor with staging + :type k_smem_layout_staged: cute.ComposedLayout + :param v_smem_layout_staged: Shared memory layout for value tensor with staging + :type v_smem_layout_staged: cute.ComposedLayout + :param o_smem_layout_staged: Shared memory layout for output tensor with staging + :type o_smem_layout_staged: cute.ComposedLayout + :param tile_sched_params: Scheduling parameters for work distribution across blocks + :type tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams + """ + + tidx, _, _ = cute.arch.thread_idx() + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_q_producer, load_q_consumer = self.make_and_init_load_q_pipeline( + storage.load_q_mbar_ptr.data_ptr() + ) + load_kv_producer, load_kv_consumer = self.make_and_init_load_kv_pipeline( + storage.load_kv_mbar_ptr.data_ptr() + ) + tma_store_pipeline = self.make_and_init_tma_store_pipeline() + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) + + math_wg_order_barrier = self.make_and_init_order_barrier( + storage.MathWarpGroupOrderBarrier.data_ptr(), + warp_group_idx - 1, + ) + + # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor( + q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor( + k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_D, PIPE) + sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner) + sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer) + + if cutlass.const_expr(self.is_persistent): + sO = storage.sO.get_tensor( + o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner + ) + else: + sO = cute.make_tensor( + cute.recast_ptr(sQ.iterator, o_smem_layout_staged.inner, self.o_dtype), + o_smem_layout_staged.outer, + ) + + seqlen_q = mQ_qdl.shape[0] + gQ_qdl = cute.flat_divide(mQ_qdl, cute.select(self.qk_mma_tiler, mode=[0, 2])) + qk_thr_mma = qk_tiled_mma.get_slice(tidx) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 2), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + + seqlen_k = mK_kdl.shape[0] + gK_kdl = cute.flat_divide(mK_kdl, cute.select(self.qk_mma_tiler, mode=[1, 2])) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 2), + cute.group_modes(tSgK_kdl, 0, 3), + ) + + gV_dkl = cute.flat_divide(mV_dkl, cute.select(self.pv_mma_tiler, mode=[1, 2])) + pv_thr_mma = pv_tiled_mma.get_slice(tidx) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 2), + cute.group_modes(tSgV_dkl, 0, 3), + ) + + producer_warp_role = warp_idx % 4 # self.num_warps_per_warp_group + + # We need this to guarantee that the Pipeline init is visible + # To all producers and consumer blocks in the Cluster + # and to finish smem init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + cute.arch.sync_threads() + + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_o) + + if warp_group_idx == self.load_warp_group_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + + tile_sched = fmha_utils.create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + + q0_index = 0 + k_index = fmha_utils.FusedMask.get_trip_start( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + ) + fusion_tile_count = fmha_utils.FusedMask.get_trip_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + q_tile_count = self.num_mma_warp_groups + k_tile_count = 2 * fusion_tile_count + + curr_block_coord_m = curr_block_coord[0] + _tQgQ = tQgQ_qdl[(None, None, 0, curr_block_coord[2])] + tQgQ = cute.domain_offset( + (0, curr_block_coord_m * self.num_mma_warp_groups), _tQgQ + ) + + if producer_warp_role == self.producer_warp_loadkv_id: + # LoadQ + if q_tile_count > 0: + q_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[(None, q0_index)], + tQsQ[(None, q_handle.index)], + tma_bar_ptr=q_handle.barrier, + ) + q0_index += 1 + + q_tile_count -= 1 + + tKgK = tKgK_kdl[(None, None, 0, curr_block_coord[2])] + tVgV = tVgV_dkl[(None, 0, None, curr_block_coord[2])] + + # Load K + if k_tile_count > 0: + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[(None, k_index)], + tKsK[(None, k_handle.index)], + tma_bar_ptr=k_handle.barrier, + ) + + k_tile_count -= 1 + + # Q1 + if q_tile_count > 0: + q_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[(None, q0_index)], + tQsQ[(None, q_handle.index)], + tma_bar_ptr=q_handle.barrier, + ) + q0_index += 1 + q_tile_count -= 1 + + # LoadV + if k_tile_count > 0: + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[(None, k_index)], + tVsV[(None, k_handle.index)], + tma_bar_ptr=k_handle.barrier, + ) + + k_index += 1 + k_tile_count -= 1 + + while k_tile_count > 0: + # Load KV + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[(None, k_index)], + tKsK[(None, k_handle.index)], + tma_bar_ptr=k_handle.barrier, + ) + + k_tile_count -= 1 + + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[(None, k_index)], + tVsV[(None, v_handle.index)], + tma_bar_ptr=v_handle.barrier, + ) + + k_index += 1 + k_tile_count -= 1 + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Mainloop + if ( + warp_group_idx == self.compute_epilogue_0_warp_group_id + or warp_group_idx == self.compute_epilogue_1_warp_group_id + ): + cute.arch.warpgroup_reg_alloc(self.num_regs_mma) + + tile_sched = fmha_utils.create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + kOuterLoads = 1 + + cP = cute.make_identity_tensor((mQ_qdl.shape[0], seqlen_k)) + gPcP = cute.local_tile(cP, self.qk_mma_tiler[:2], (None, None)) + + while work_tile.is_valid_tile: + for i in cutlass.range((warp_group_idx - 1) * kOuterLoads, unroll=1): + load_q_consumer.advance() + + curr_block_coord = work_tile.tile_idx + + # _wg_coord_1 is work_tile.tile_idx[1], which is always 0. + _wg_coord_0 = self.num_mma_warp_groups * curr_block_coord[0] + ( + warp_group_idx - 1 + ) + _wg_coord_1 = curr_block_coord[1] + + wg_coord = (_wg_coord_0, _wg_coord_1, *curr_block_coord[2:]) + + # Mainloop setup QK + tSsQ = qk_thr_mma.partition_A(sQ) # (MMA,MMA_M,MMA_K,PIPE) + tSsK = qk_thr_mma.partition_B(sK) # (MMA,MMA_N,MMA_K,PIPE) + tSrQ = qk_thr_mma.make_fragment_A(tSsQ) # (MMA,MMA_M,MMA_K,PIPE) + tSrK = qk_thr_mma.make_fragment_B(tSsK) # (MMA,MMA_N,MMA_K,PIPE) + + # Prepare: MMA PV + thr_mma_pv = pv_tiled_mma.get_slice(tidx) + + # Mainloop setup PV + tOsV = thr_mma_pv.partition_B(sV) # (MMA,MMA_N,MMA_K,PIPE) + tOrV = thr_mma_pv.make_fragment_B(tOsV) # (MMA,MMA_M,MMA_N,PIPE) + + q_handle = load_q_consumer.wait() + + # mapping into QK accumulator + ptPcP = qk_thr_mma.partition_C(gPcP) + + # Allocate PV acc + pv_acc_shape = pv_thr_mma.partition_shape_C( + (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) + ) + acc_pv = pv_thr_mma.make_fragment_C(pv_acc_shape) + + qk_acc_shape = qk_thr_mma.partition_shape_C( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + + s_max_layout = cute.make_layout( + cute.size(layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0]) + ) + s_max = cute.make_rmem_tensor_like(s_max_layout, self.qk_acc_dtype) + a_sum = cute.make_rmem_tensor_like(s_max, cutlass.Float32) + + kv_offset = fmha_utils.FusedMask.get_trip_start( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + ) + + masked_leading_count = fmha_utils.FusedMask.get_masked_leading_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + unmasked_trip_count = fmha_utils.FusedMask.get_unmasked_trip_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + # mapping into QK accumulator + tPcP = cute.slice_(ptPcP, (None, None, None, wg_coord[0], kv_offset)) + kv_offset += 1 + + qk_acc_shape = qk_thr_mma.partition_shape_C( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + + # Allocate QK acc + acc_qk = qk_thr_mma.make_fragment_C(qk_acc_shape) + k_handle = load_kv_consumer.wait_and_advance() + math_wg_order_barrier.wait() + + # MMA QK + cute.nvgpu.warpgroup.fence() + + gemm_zero_acc( + qk_tiled_mma, + tSrQ[(None, None, None, q_handle.index)], + tSrK[(None, None, None, k_handle.index)], + acc_qk, + ) + cute.nvgpu.warpgroup.commit_group() + + math_wg_order_barrier.arrive() + + # Wait for the pipeline MMAs to drain + cute.nvgpu.warpgroup.wait_group(0) + + s_max, a_sum = softmax_step( + True, + self.mask_type, + acc_qk, + qk_tiled_mma, + tPcP, + s_max, + a_sum, + acc_qk, + qk_tiled_mma, + scale_softmax_log2, + seqlen_k, + seqlen_q, + window_size_left, + window_size_right, + True, + ) + + acc_qk_fixed = make_acc_into_op( + acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype + ) + + v_handle = load_kv_consumer.wait_and_advance() + + # MMA PV + cute.nvgpu.warpgroup.fence() + + gemm_zero_acc( + pv_tiled_mma, + acc_qk_fixed, + tOrV[(None, None, None, v_handle.index)], + acc_pv, + ) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + + k_handle.release() + v_handle.release() + + if masked_leading_count >= 1: + masked_leading_count -= 1 + load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum = ( + self.compute( + True, + masked_leading_count, + qk_thr_mma, + acc_pv, + qk_tiled_mma, + pv_tiled_mma, + load_kv_consumer, + q_handle, + tSrQ, + tSrK, + s_max, + a_sum, + tOrV, + ptPcP, + wg_coord, + kv_offset, + scale_softmax_log2, + seqlen_k, + seqlen_q, + qk_acc_shape, + window_size_left, + window_size_right, + ) + ) + else: + unmasked_trip_count -= 1 + + load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum = self.compute( + False, + unmasked_trip_count, + qk_thr_mma, + acc_pv, + qk_tiled_mma, + pv_tiled_mma, + load_kv_consumer, + q_handle, + tSrQ, + tSrK, + s_max, + a_sum, + tOrV, + ptPcP, + wg_coord, + kv_offset, + scale_softmax_log2, + seqlen_k, + seqlen_q, + qk_acc_shape, + window_size_left, + window_size_right, + ) + + k_tile_count = fmha_utils.FusedMask.get_masked_trailing_count( + self.mask_type, + curr_block_coord, + self.cta_tiler, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + k_tile_count, + ) + + # Use fusion in softmax + load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum = self.compute( + True, + k_tile_count, + qk_thr_mma, + acc_pv, + qk_tiled_mma, + pv_tiled_mma, + load_kv_consumer, + q_handle, + tSrQ, + tSrK, + s_max, + a_sum, + tOrV, + ptPcP, + wg_coord, + kv_offset, + scale_softmax_log2, + seqlen_k, + seqlen_q, + qk_acc_shape, + window_size_left, + window_size_right, + ) + + if cutlass.const_expr(self.is_persistent): + q_handle.release() + + # Wait for the pipeline MMAs to drain + cute.nvgpu.warpgroup.wait_group(0) + + # acc_pv updated + lse = tail( + s_max, a_sum, acc_pv, pv_tiled_mma, scale_softmax, scale_output + ) + + if warp_group_idx == self.compute_epilogue_0_warp_group_id: + for i in cutlass.range_constexpr( + kOuterLoads * (self.num_mma_warp_groups - 0) + ): + load_q_consumer.advance() + + if cutlass.const_expr(self.num_mma_warp_groups >= 2): + if warp_group_idx == self.compute_epilogue_1_warp_group_id: + for i in cutlass.range_constexpr( + kOuterLoads * (self.num_mma_warp_groups - 1) + ): + load_q_consumer.advance() + + math_wg_order_barrier.wait() + + # store log-sum-exp (LSE) + thr_mma = pv_tiled_mma.get_slice(tidx) + + gLSE_full = cute.local_tile( + mLse_qdl, self.pv_mma_tiler[:2], (None, None, None) + ) + gLSE = cute.slice_( + gLSE_full, (None, None, wg_coord[0], wg_coord[1], wg_coord[2]) + ) + + tOgLSE = thr_mma.partition_C(gLSE) + cO = cute.make_identity_tensor( + (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) + ) + tOcO = thr_mma.partition_C(cO) + + if tOcO[0][1] == 0: + tOgLSE_mn = cute.make_tensor( + tOgLSE.iterator, layout_acc_mn(pv_tiled_mma, tOgLSE.layout) + ) + tOcO_mn = cute.make_tensor( + tOcO.iterator, layout_acc_mn(pv_tiled_mma, tOcO.layout) + ) + for i in cutlass.range_constexpr(cute.size(tOgLSE_mn, mode=[0])): + if ( + tOcO_mn[i][0] + wg_coord[0] * self.pv_mma_tiler[0] + < seqlen_q + ): + tOgLSE_mn[(i, 0)] = lse[i] + + # Epilogue + cO = cute.make_identity_tensor((self.cta_tiler[0], self.cta_tiler[2])) + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.o_layout, + elem_ty_d=self.o_dtype, + elem_ty_acc=self.pv_acc_dtype, + ) + + copy_atom_O = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + self.o_layout.is_m_major_c(), + 4, + ), + self.o_dtype, + ) + + tiled_copy_O_Atom = cute.make_tiled_copy_C_atom( + copy_atom_O, pv_tiled_mma + ) + + tiled_copy_r2s = cute.make_tiled_copy_S( + copy_atom_r2s, + tiled_copy_O_Atom, + ) + + thr_copy_r2s = tiled_copy_r2s.get_slice( + tidx % self.num_threads_per_warp_group + ) + tRS_sD = thr_copy_r2s.partition_D(sO) + tRS_rAcc = tiled_copy_r2s.retile(acc_pv) + + # Allocate D registers. + rD_shape = cute.shape(thr_copy_r2s.partition_S(sO)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + + tRS_rD = cute.make_rmem_tensor_like(tRS_rD_layout, self.pv_acc_dtype) + size_tRS_rD = cute.size(tRS_rD) + + gD = cute.local_tile( + mO_qdl, + self.pv_mma_tiler[:2], + (wg_coord[0], 0, wg_coord[2]), + ) + + sepi_for_tma_partition = cute.group_modes(sO, 0, 2) + tcgc_for_tma_partition = cute.zipped_divide(gD, self.epi_tile) + + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_o, + 0, + cute.make_layout(1), + sepi_for_tma_partition, + tcgc_for_tma_partition, + ) + + epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1]) + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # Copy from accumulators to D registers + for epi_v in cutlass.range_constexpr(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + # Type conversion + tRS_rD_out = cute.make_rmem_tensor_like(tRS_rD_layout, self.o_dtype) + acc_vec = tRS_rD.load() + tRS_rD_out.store(acc_vec.to(self.o_dtype)) + + # Copy from D registers to shared memory + epi_buffer = epi_idx % self.epi_stage + cute.copy( + tiled_copy_r2s, + tRS_rD_out, + tRS_sD[(None, None, None, epi_buffer, warp_group_idx - 1)], + ) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + pipeline.arrive_and_wait( + barrier_id=warp_group_idx, + num_threads=self.num_threads_per_warp_group, + ) + + # only one warp in each warpgroup copy shared memory to global memory + if warp_idx == 4 or warp_idx == 8: + cute.copy( + tma_atom_o, + bSG_sD[(None, epi_buffer, warp_group_idx - 1)], + bSG_gD[(None, epi_idx)], + ) + + tma_store_pipeline.producer_commit() + tma_store_pipeline.producer_acquire() + + pipeline.arrive_and_wait( + barrier_id=warp_group_idx, + num_threads=self.num_threads_per_warp_group, + ) + + math_wg_order_barrier.arrive() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + return + + @cute.jit + def compute( + self, + fusion: bool, + k_tile_count: cutlass.Int32, + qk_thr_mma: cute.ThrMma, + acc_pv: cute.ThrMma, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + load_kv_consumer: pipeline.PipelineConsumer, + q_handle: pipeline.PipelineConsumer.ImmutableResourceHandle, + tSrQ: cute.Tensor, + tSrK: cute.Tensor, + s_max: cute.Tensor, + a_sum: cute.Tensor, + tOrV: cute.Tensor, + ptPcP: cute.Tensor, + wg_coord: tuple, + kv_offset: cutlass.Int32, + scale_softmax_log2: cutlass.Float32, + seqlen_k: cutlass.Int32, + seqlen_q: cutlass.Int32, + qk_acc_shape: cute.Shape, + window_size_left: Optional[cutlass.Int32], + window_size_right: Optional[cutlass.Int32], + ) -> Tuple[ + pipeline.PipelineConsumer, + cutlass.Int32, + cutlass.Int32, + cute.Tensor, + cute.Tensor, + ]: + while k_tile_count > 0: + k_tile_count -= 1 + + tPcP = cute.slice_(ptPcP, (None, None, None, wg_coord[0], kv_offset)) + kv_offset += 1 + + # Allocate QK acc + acc_qk = qk_thr_mma.make_fragment_C(qk_acc_shape) + + k_handle = load_kv_consumer.wait_and_advance() + + # MMA QK + cute.nvgpu.warpgroup.fence() + + gemm_zero_acc( + qk_tiled_mma, + tSrQ[(None, None, None, q_handle.index)], + tSrK[(None, None, None, k_handle.index)], + acc_qk, + ) + + cute.nvgpu.warpgroup.commit_group() + + tok = load_kv_consumer.try_wait() + + # Wait for the pipeline MMAs to drain + cute.nvgpu.warpgroup.wait_group(0) + + s_max, a_sum = softmax_step( + fusion, + self.mask_type, + acc_qk, + qk_tiled_mma, + tPcP, + s_max, + a_sum, + acc_pv, + pv_tiled_mma, + scale_softmax_log2, + seqlen_k, + seqlen_q, + window_size_left, + window_size_right, + ) + + acc_qk_fixed = make_acc_into_op( + acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype + ) + + v_handle = load_kv_consumer.wait_and_advance(tok) + + # MMA PV + cute.nvgpu.warpgroup.fence() + + pv_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + cute.gemm( + pv_tiled_mma, + acc_pv, + acc_qk_fixed, + tOrV[(None, None, None, v_handle.index)], + acc_pv, + ) + + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(0) + + k_handle.release() + v_handle.release() + + return load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum + + @cute.jit + def softmax_step( + fusion: bool, + mask_type: fmha_utils.MaskEnum, + acc_qk: cute.ThrMma, + tiled_mma_qk: cute.TiledMma, + count_qk: cute.Tensor, + s_max: cute.Tensor, + a_sum: cute.Tensor, + acc_pv: cute.ThrMma, + tiled_mma_pv: cute.TiledMma, + scale_softmax_log2: cutlass.Float32, + seqlen_k: cutlass.Int32, + seqlen_q: cutlass.Int32, + window_size_left: Optional[cutlass.Int32], + window_size_right: Optional[cutlass.Int32], + is_first_iter: bool = False, + ) -> Tuple[cute.Tensor, cute.Tensor]: + if cutlass.const_expr(fusion): + fmha_utils.FusedMask.apply_mask( + mask_type, + acc_qk, + count_qk, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + acc_qk_mn = cute.make_tensor( + acc_qk.iterator, layout_acc_mn(tiled_mma_qk, acc_qk.layout) + ) + + reduction_target_qk = reduction_target_n(tiled_mma_qk) + red_rank = cute.rank(reduction_target_qk) + + s_max_prev = None + acc_pv_mn = None + if cutlass.const_expr(is_first_iter): + # Linear reduction is faster for the first iteration + for i in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[0])): + s_max[i] = acc_qk_mn[i, 0] + + for j in cutlass.range_constexpr(1, cute.size(acc_qk_mn, mode=[1])): + for i in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[0])): + s_max[i] = cute.arch.fmax(s_max[i], acc_qk_mn[i, j]) + else: + acc_pv_mn = cute.make_tensor( + acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout) + ) + s_max_prev = cute.make_rmem_tensor_like(s_max, s_max._dtype) + + for i in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[0])): + if cutlass.const_expr(not is_first_iter): + s_max_prev[i] = s_max[i] + + # Linear reduction is faster here, as well + for j in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[1])): + s_max[i] = cutlass.max(s_max[i], acc_qk_mn[i, j]) + + # reduce max + for r in cutlass.range_constexpr(red_rank): + s_max[i] = cute.arch.warp_reduction_max( + s_max[i], threads_in_group=reduction_target_qk.shape[r] + ) + + local_max = s_max[i] + if s_max[i] == -cutlass.Float32.inf: + local_max = 0.0 + scale_max = scale_softmax_log2 * local_max + + for j in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[1])): + acc_qk_mn[i, j] = cute.math.exp2( + scale_softmax_log2 * acc_qk_mn[i, j] - scale_max, fastmath=True + ) + + _a_sum = 0.0 + if cutlass.const_expr(not is_first_iter): + s_max_cur = s_max[i] + if s_max[i] == -cutlass.Float32.inf: + s_max_cur = 0.0 + scale_pv = cute.math.exp2( + (s_max_prev[i] - s_max_cur) * scale_softmax_log2, fastmath=True + ) + a_sum[i] *= scale_pv + + for j in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[1])): + acc_pv_mn[i, j] *= scale_pv + + _a_sum = a_sum[i] + + a_sum[i] = _a_sum + acc_qk_mn[i, None].load().reduce( + cute.ReductionOp.ADD, cutlass.Float32.zero, 0 + ) + + return s_max, a_sum + + @cute.jit + def reduction_target_n(tiled_mma): + separated = layout_separate( + tiled_mma.shape_mnk[0], + cute.make_layout(tiled_mma.tv_layout_C.shape[0]), + tiled_mma.tv_layout_C.stride[0], + ) + return separated[1] + + @cute.jit + def convert_c_layout_to_a_layout(c, a): + return cute.make_layout( + (a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))), + stride=( + c.stride[0], + c.stride[1], + (c.stride[2], cute.size(a, mode=[2]) * c.stride[0][2]), + ), + ) + + @cute.jit + def make_acc_into_op(acc, operand_layout_tv, Element): + operand = cute.make_rmem_tensor_like( + convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]), + Element, + ) + operand_as_acc = cute.make_tensor(operand.iterator, acc.layout) + acc_vec = acc.load() + operand_as_acc.store(acc_vec.to(Element)) + + if cutlass.const_expr(Element.width == 8 and True): + ## 00 11 22 33 00 11 22 33 acc layout + ## 00 00 11 11 22 22 33 33 operand layout + ## BB AA AA BB AA BB BB AA conflict-free exchange pattern + ## 16-bit exchange; so process two at a time potentially + # int tid = threadIdx.x % 4; + tidx, _, _ = cute.arch.thread_idx() + tid = tidx % 4 + values_u32 = cute.recast_tensor(operand, cutlass.Uint32) + for n in cutlass.range_constexpr(cute.size(values_u32, mode=[1])): + for k in cutlass.range_constexpr(cute.size(values_u32, mode=[2])): + for ii in cutlass.range_constexpr(0, 8, 4): + values_tmp_0 = values_u32[ii // 2 + 0, n, k] + values_tmp_1 = values_u32[ii // 2 + 1, n, k] + + ## step A: + ## t 1 v 0 -> t 0 v 1 + ## t 2 v 0 -> t 1 v 0 + ## t 0 v 1 -> t 2 v 0 + ## t 3 v 1 -> t 3 v 1 + + v_to_send = 1 + if tid == 1 or tid == 2: + v_to_send = 0 + + v_to_recv = v_to_send + t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF + + values_tmp_a = values_tmp_1 + if v_to_send == 0: + values_tmp_a = values_tmp_0 + + values_tmp_a = cute.arch.shuffle_sync_op( + values_tmp_a, t_to_recv_from, 0xFFFFFFFF, 7199 + ) + + # step B: + # t 0 v 0 -> t 0 v 0 + # t 3 v 0 -> t 1 v 1 + # t 1 v 1 -> t 2 v 1 + # t 2 v 1 -> t 3 v 0 + + v_to_send = 1 - v_to_send + v_to_recv = 1 - v_to_recv + t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF + + values_tmp_b = values_tmp_1 + if v_to_send == 0: + values_tmp_b = values_tmp_0 + + values_tmp_b = cute.arch.shuffle_sync_op( + values_tmp_b, t_to_recv_from, 0xFFFFFFFF, 7199 + ) + + # __byte_perm + order = 0x5410 + if v_to_send == 0: + order = 0x1054 + + values_u32[ii // 2 + 0, n, k] = cute.arch.prmt( + values_tmp_a, + values_tmp_b, + order, + ) + + order = 0x7632 + if v_to_send == 0: + order = 0x3276 + values_u32[ii // 2 + 1, n, k] = cute.arch.prmt( + values_tmp_a, values_tmp_b, order + ) + return operand + + @cute.jit + def tail(s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output): + """ + Final processing step for FMHA that computes log-sum-exp (LSE) and scales the output. + + This function performs the following operations: + 1. Reduces the attention sums across warps using butterfly shuffle + 2. Computes the log-sum-exp (LSE) for numerical stability + 3. Applies softmax scaling and output scaling to the accumulated values + 4. Handles edge cases like zero sums and NaN values + + :param s_max: Maximum attention scores for each position (for numerical stability) + :type s_max: cute.Tensor + :param a_sum: Sum of attention scores after softmax + :type a_sum: cute.Tensor + :param acc_pv: Accumulated P*V values from the attention computation + :type acc_pv: cute.ThrMma + :param tiled_mma_pv: Tiled MMA for P*V computation + :type tiled_mma_pv: cute.TiledMma + :param scale_softmax: Scaling factor for softmax computation + :type scale_softmax: cutlass.Float32 + :param scale_output: Scaling factor for final output + :type scale_output: cutlass.Float32 + + :return: Log-sum-exp values for each position + :rtype: cute.Tensor + """ + # Create tensor view of accumulated P*V values with M*N layout + acc_pv_mn = cute.make_tensor( + acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout) + ) + reduction_target = reduction_target_n(tiled_mma_pv) + red_rank = cute.rank(reduction_target) + for r in cutlass.range_constexpr(red_rank): + for i in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[0])): + a_sum[i] = cute.arch.warp_reduction_sum( + a_sum[i], threads_in_group=reduction_target.shape[r] + ) + + acc_mn = cute.make_tensor( + acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout) + ) + + lse = cute.make_rmem_tensor_like(a_sum, a_sum._dtype) + for i in cutlass.range_constexpr(cute.size(acc_mn, mode=[0])): + sum = a_sum[i] + inv_sum = cute.arch.rcp_approx(sum) + if sum == 0.0 or sum != sum: + inv_sum = 1.0 + + lse[i] = s_max[i] * scale_softmax + _math.log(sum) + if sum == 0.0 or sum != sum: + lse[i] = cutlass.Float32.inf + + rp_dropout = 1 + scale = rp_dropout * inv_sum + for j in cutlass.range_constexpr(cute.size(acc_mn, mode=[1])): + acc_mn[i, j] *= scale * scale_output + + return lse + + @cute.jit + def layout_separate(thr, src, ref): + lt = cute.make_layout(()) + ge = cute.make_layout(()) + + for k, v in enumerate(ref): + if cutlass.const_expr(v < thr): + lt = cute.append(lt, src[k]) + else: + ge = cute.append(ge, src[k]) + + r = None + if cutlass.const_expr(cute.rank(lt) == 1): + r = cute.append(lt, ge) + else: + r = cute.append(cute.append(cute.make_layout(()), lt), ge) + return r + + @cute.jit + def gemm_zero_acc(tiled_mma, A, B, C): + rA = cute.rank(A) + rB = cute.rank(B) + rC = cute.rank(C) + if cutlass.const_expr(rA == 2 and rB == 2 and rC == 1): + for k_block_idx in range(cute.size(A, mode=[1]), unroll_full=True): + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, k_block_idx != 0) + cute.gemm( + tiled_mma, + C, + A[None, k_block_idx], + B[None, k_block_idx], + C, + ) + elif cutlass.const_expr(rA == 3 and rB == 3 and rC == 3): + for k_block_idx in range(cute.size(A, mode=[2]), unroll_full=True): + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, k_block_idx != 0) + cute.gemm( + tiled_mma, + C, + A[None, None, k_block_idx], + B[None, None, k_block_idx], + C, + ) + else: + assert 0 + + @cute.jit + def layout_acc_mn(tiled_mma, acc): + separated = layout_separate( + tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1] + ) + + V_M = separated[0] + V_N = separated[1] + V_M1 = None + V_N1 = None + if cutlass.const_expr(cute.rank(V_M) == 1): + V_M1 = cute.append(V_M, acc[1]) + else: + V_M1 = cute.append(cute.append(cute.make_layout(()), V_M), acc[1]) + + if cutlass.const_expr(cute.rank(V_N) == 1): + V_N1 = cute.append(V_N, acc[2]) + else: + V_N1 = cute.append(cute.append(cute.make_layout(()), V_N), acc[2]) + r = None + if cutlass.const_expr(cute.rank(V_M1) == 1): + r = cute.append(V_M1, V_N1) + else: + r = cute.append(cute.append(cute.make_layout(()), V_M1), V_N1) + return r + + def make_and_init_load_q_pipeline(self, load_q_mbar_ptr): + load_q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.load_warp_group_id]), + ) + load_q_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_warps_per_warp_group, + ) + return pipeline.PipelineTmaAsync.create( + barrier_storage=load_q_mbar_ptr, + num_stages=self.q_stage, + producer_group=load_q_producer_group, + consumer_group=load_q_consumer_group, + tx_count=self.tma_copy_q_bytes, + ).make_participants() + + def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): + load_kv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.load_warp_group_id]), + ) + load_kv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_mma_warp_groups * self.num_warps_per_warp_group, + ) + return pipeline.PipelineTmaAsync.create( + barrier_storage=load_kv_mbar_ptr, + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + tx_count=self.tma_copy_kv_bytes, + ).make_participants() + + def make_and_init_tma_store_pipeline(self): + tma_store_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 1, + ) + return pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=tma_store_producer_group, + ) + + def make_and_init_order_barrier(self, order_mbar_ptr, group_id): + StagesPerMathWarpGroup = 1 + return pipeline.PipelineOrder.create( + barrier_storage=order_mbar_ptr, + depth=StagesPerMathWarpGroup, + length=self.num_mma_warp_groups, + group_id=group_id, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_threads_per_warp_group, + ), + ) + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: tuple[int, int], + mcast_dim: int, + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout_staged: Shared memory layout for the tensor + :type smem_layout_staged: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = ( + cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + @staticmethod + def can_implement( + q_shape: Tuple[int, int, int, int], + k_shape: Tuple[int, int, int, int], + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + qk_acc_dtype: Type[cutlass.Numeric], + pv_acc_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + is_persistent: bool, + scale_softmax: float, + window_size: Tuple[int, int], + iterations: int, + ) -> Tuple[bool, str]: + """Check if the FMHA kernel can be implemented with the given parameters. + + This method validates that the input parameters are compatible with the Hopper + Fused Multi-Head Attention implementation. It checks tensor shapes, data types, + window sizes, and other constraints to ensure the kernel can be successfully + compiled and executed. + + :param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length, + H=number of heads, D=head dimension + :type q_shape: Tuple[int, int, int, int] + :param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length, + H_k=number of key heads, D=head dimension + :type k_shape: Tuple[int, int, int, int] + :param in_dtype: Input data type for query, key and value tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param qk_acc_dtype: Accumulator data type for query-key matrix multiplication + :type qk_acc_dtype: Type[cutlass.Numeric] + :param pv_acc_dtype: Accumulator data type for probability-value matrix multiplication + :type pv_acc_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: Matrix multiply accumulate tile shape (M, N) + :type mma_tiler_mn: Tuple[int, int] + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param scale_softmax: Attention score scaling factor + :type scale_softmax: float + :param window_size: Sliding window size (left, right) for attention masking + :type window_size: Tuple[int, int] + :param iterations: Number of iterations to run for performance testing + :type iterations: int + + :return: Tuple of (can_implement, error_message) where can_implement is True if the kernel + can be implemented, False otherwise, and error_message contains the reason for failure + :rtype: Tuple[bool, str] + """ + + # Unpack parameters + b, s_q, h, d = q_shape + b_, s_k, h_k, d_ = k_shape + window_size_left, window_size_right = window_size + + if b != b_: + return False, "q & k must have the same batch size" + + if d != d_: + return False, "q & k must have the same head dimension" + + if window_size_left >= s_k - 1: + return False, "window_size_left must be less than s_k_max - 1" + if window_size_right >= s_q - 1: + return False, "window_size_right must be less than s_q_max - 1" + + if h % h_k != 0: + return False, "h must be divisible by h_k" + + if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16, cutlass.BFloat16}: + return False, "in_dtype must be Float16, BFloat16, Float8E4M3FN" + + if out_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16, cutlass.BFloat16}: + return False, "out_dtype must be Float16, BFloat16, Float8E4M3FN" + + if qk_acc_dtype not in {cutlass.Float32}: + return False, "qk_acc_dtype must be Float32" + + if pv_acc_dtype not in {cutlass.Float32}: + return False, "pv_acc_dtype must be Float32" + + if iterations < 1: + return False, "iterations must be at least 1" + + if ( + in_dtype.width == 16 + and out_dtype.width == 16 + and ( + (d_ == 256 and mma_tiler_mn[1] >= 128) + or (d_ == 128 and mma_tiler_mn[1] >= 256) + ) + ) or ( + in_dtype.width == 8 + and out_dtype.width == 8 + and d_ == 256 + and mma_tiler_mn[1] >= 256 + ): + return False, "not enough smem" + + if is_persistent and ( + ( + in_dtype.width == 16 + and out_dtype.width == 16 + and ( + (d_ == 128 and mma_tiler_mn[1] >= 256) + or (d_ == 256 and mma_tiler_mn[1] > 32) + ) + ) + or ( + in_dtype.width == 8 + and out_dtype.width == 8 + and d_ == 256 + and mma_tiler_mn[1] == 256 + ) + ): + return False, "not supported persistent" + + return True, None + + +def run( + q_shape: Tuple[int, int, int, int], + k_shape: Tuple[int, int, int, int], + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + qk_acc_dtype: Type[cutlass.Numeric], + pv_acc_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + is_persistent: bool, + is_causal: bool, + bottom_right_align: bool, + scale_q: float, + scale_k: float, + scale_v: float, + inv_scale_o: float, + scale_softmax: float, + window_size: Tuple[int, int], + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """Execute Fused Multi-Head Attention (FMHA) on Hopper architecture and validate results. + + This function creates random input tensors for query, key, and value, then performs the + complete FMHA computation pipeline. It supports configurable data types, tiling parameters, + and various attention masking options. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + The implementation leverages specialized tensor memory operations and efficient math + operations optimized for Hopper architecture, including pipelined computation stages + for maximum throughput. + + :param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length, + H=number of heads, D=head dimension. + If S_q is a tuple, it is the variable sequence length. + :type q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length, + H_k=number of key heads (H must be divisible by H_k), D=head dimension. + If S_k is a tuple, it is the variable sequence length. + :type k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :param in_dtype: Input data type for query, key and value tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param qk_acc_dtype: Accumulator data type for query-key matrix multiplication + :type qk_acc_dtype: Type[cutlass.Numeric] + :param pv_acc_dtype: Accumulator data type for probability-value matrix multiplication + :type pv_acc_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: Matrix multiply accumulate tile shape (M, N) + :type mma_tiler_mn: Tuple[int, int] + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_causal: Whether to apply causal masking + :type is_causal: bool + :param bottom_right_align: Whether to use bottom right align, under this settion, the end of q is aligned with the end of k. + :type bottom_right_align: bool + :param scale_q: Scaling factor for query tensor + :type scale_q: float + :param scale_k: Scaling factor for key tensor + :type scale_k: float + :param scale_v: Scaling factor for value tensor + :type scale_v: float + :param inv_scale_o: Inverse scaling factor for output tensor + :type inv_scale_o: float + :param scale_softmax: Attention score scaling factor (defaults to 1/sqrt(D) if set to 0) + :type scale_softmax: float + :param window_size: Sliding window size (left, right) for attention masking. Controls which positions each query can attend to. Negative values disable windowing. + :type window_size: Tuple[int, int] + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + :return: Execution time of the FMHA kernel in microseconds + :rtype: float + """ + print("Running Hopper SM90 FMHA test with:") + print(f" q_shape: {q_shape}") + print(f" k_shape: {k_shape}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" qk_acc_dtype: {qk_acc_dtype}") + print(f" pv_acc_dtype: {pv_acc_dtype}") + print(f" mma_tiler_mn: {mma_tiler_mn}") + print(f" is_persistent: {is_persistent}") + print(f" is_causal: {is_causal}") + print(f" bottom_right_align: {bottom_right_align}") + print(f" scale_q: {scale_q}") + print(f" scale_k: {scale_k}") + print(f" scale_v: {scale_v}") + print(f" inv_scale_o: {inv_scale_o}") + print(f" scale_softmax: {scale_softmax}") + print(f" window_size: {window_size}") + print(f" tolerance: {tolerance}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + ret, msg = HopperFusedMultiHeadAttentionForward.can_implement( + q_shape, + k_shape, + in_dtype, + out_dtype, + qk_acc_dtype, + pv_acc_dtype, + mma_tiler_mn, + is_persistent, + scale_softmax, + window_size, + iterations, + ) + if not ret: + raise TypeError(msg) + + # Unpack parameters + b, s_q, h, d = q_shape + b_, s_k, h_k, d_ = k_shape + window_size_left, window_size_right = window_size + if window_size_left == -1: + window_size_left = None + if window_size_right == -1: + window_size_right = None + + h_r = h // h_k + + torch.manual_seed(1111) + + def create_and_permute_tensor( + b, s, h_k, h_r, d, dtype, is_dynamic_layout=True, tensor_name="" + ): + # (b, s, h_k, h_r, d) -> (s, d, h_r, h_k, b) + # torch SPDA order is (h_k, h_r), then kernel is (h_r, h_k) + shape = (b, s, h_k, h_r, d) + permute_order = (1, 4, 3, 2, 0) + is_fp8 = dtype in {cutlass.Float8E4M3FN} + leading_dim = 1 + if is_fp8 and tensor_name == "v": + permute_order = (4, 1, 3, 2, 0) + leading_dim = 0 + shape = (b, d, h_k, h_r, s) + + # torch does not support fp8 type + torch_dtype = cutlass.torch.dtype(dtype) if not is_fp8 else torch.int8 + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=-2, + max_val=2, + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor_gpu = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor_gpu + + q_ref, q_tensor, q_torch = create_and_permute_tensor( + b, s_q, h_k, h_r, d, in_dtype, is_dynamic_layout=True + ) + k_ref, k_tensor, k_torch = create_and_permute_tensor( + b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True + ) + v_ref, v_tensor, v_torch = create_and_permute_tensor( + b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True, tensor_name="v" + ) + o_ref, o_tensor, o_torch = create_and_permute_tensor( + b, s_q, h_k, h_r, d, out_dtype, is_dynamic_layout=True + ) + lse_ref, lse_tensor, lse_torch = create_and_permute_tensor( + b, s_q, h_k, h_r, 1, qk_acc_dtype, is_dynamic_layout=True + ) + + mma_tiler = (*mma_tiler_mn, d) + + mask_type = fmha_utils.MaskEnum.WINDOW_MASK + if bottom_right_align: + mask_type = fmha_utils.MaskEnum.WINDOW_MASK_INFERENCE + if is_causal: + window_size_right = 0 + elif window_size_left is None and window_size_right is None: + if s_k % mma_tiler_mn[1] != 0: + mask_type = fmha_utils.MaskEnum.RESIDUAL_MASK + + # To avoid mask out the whole row which results in NaN in softmax + def check_seqlen_valid( + s_q, s_k, window_size_left, window_size_right, bottom_right_align + ): + for i in range(s_q): + offset = 0 if not bottom_right_align else s_k - s_q + + s_q_start = 0 if window_size_left is None else i + offset - window_size_left + s_q_end = ( + s_q if window_size_right is None else i + offset + window_size_right + ) + s_q_min = max(s_q_start, 0) + s_q_max = min(s_q_end, s_k) + + if s_q_max - s_q_min == 0 and (i != 0 and i != s_q - 1): + return False + return True + + need_check_seqlen_valid = ( + window_size_left is not None or window_size_right is not None + ) + if need_check_seqlen_valid and not check_seqlen_valid( + s_q, + s_k, + window_size_left, + window_size_right, + bottom_right_align, + ): + raise ValueError("sliding window doesn't support current setting") + + fmha = HopperFusedMultiHeadAttentionForward( + qk_acc_dtype, + pv_acc_dtype, + mma_tiler, + is_persistent, + mask_type, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + if scale_softmax == 0.0: # default to 1/sqrt(d) + scale_softmax = 1.0 / math.sqrt(q_shape[1]) + + scale_softmax = scale_q * scale_k * scale_softmax + + LOG2_E = 1.4426950408889634074 + scale_softmax_log2 = scale_softmax * LOG2_E + scale_output = scale_v * inv_scale_o + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + # compile fmha kernel + compiled_fmha = cute.compile( + fmha, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + scale_softmax_log2, + scale_softmax, + scale_output, + ( + window_size_left + if window_size_left is None + else cutlass.Int32(window_size_left) + ), + ( + window_size_right + if window_size_right is None + else cutlass.Int32(window_size_right) + ), + current_stream, + ) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + def run_torch_fmha( + q, + k, + v, + scale_softmax=1.0, + scale_output=1.0, + is_causal=False, + window_size_left=None, + window_size_right=None, + ): + s_q, d, h_r, h_k, b = q.shape + s_k = k.shape[0] + + # broadcast k and v to have the same shape as q + k = k.expand(s_k, d, h_r, h_k, b) + v = v.expand(s_k, d, h_r, h_k, b) + + q_tmp = q.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_q, d) + k_tmp = k.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_k, d) + v_tmp = v.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_k, d) + + cur_S = torch.einsum("bhqd,bhkd->bhqk", q_tmp, k_tmp) + + # For causal masking, disable right-side windowing (no future tokens) + if is_causal: + window_size_right = 0 + + if window_size_left is not None or window_size_right is not None: + q_coords = torch.arange(0, s_q).cuda().view(-1, 1) + k_coords = torch.arange(0, s_k).cuda().view(1, -1) + offset = 0 if not bottom_right_align else s_k - s_q + if window_size_left is None: + _mask = k_coords > q_coords + offset + window_size_right + elif window_size_right is None: + _mask = k_coords < q_coords + offset - window_size_left + else: + _mask = (k_coords > q_coords + offset + window_size_right) | ( + k_coords < q_coords + offset - window_size_left + ) + cur_S = cur_S.masked_fill(_mask.cpu(), -torch.inf) + + p_tmp = torch.softmax(cur_S * scale_softmax, dim=-1) + ref = torch.einsum("bhsl,bhld->bhsd", p_tmp, v_tmp) + ref = ref.view(b, h_r, h_k, s_q, d).permute(3, 4, 1, 2, 0) * scale_output + + cur_S_max = torch.max(cur_S, dim=-1, keepdim=True).values + cur_S_sum = torch.sum( + torch.exp2((cur_S - cur_S_max) * scale_softmax_log2), dim=-1, keepdim=True + ) + + lse = cur_S_max * scale_softmax + torch.log(cur_S_sum) + + # [B, H, Q, 1]->[Q,1,H,B] + lse = lse.permute(2, 3, 1, 0).contiguous().view(s_q, 1, h_r, h_k, b) + + return ref, lse + + if not skip_ref_check: + # Execute kernel oneshot for verify + compiled_fmha( + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + scale_softmax_log2, + scale_softmax, + scale_output, + ( + window_size_left + if window_size_left is None + else cutlass.Int32(window_size_left) + ), + ( + window_size_right + if window_size_right is None + else cutlass.Int32(window_size_right) + ), + current_stream, + ) + + print("Verifying results...") + o_ref, lse_ref = run_torch_fmha( + q_ref, + k_ref, + v_ref, + scale_softmax, + scale_output, + is_causal, + window_size_left, + window_size_right, + ) + + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o_tensor, o_fp32) + + ref_o_f32, ref_o_f32_torch = cutlass_torch.cute_tensor_like( + o_ref, + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + + if out_dtype.is_float and out_dtype.width <= 8: + ref_narrow_precision, _ = cutlass_torch.cute_tensor_like( + torch.empty_strided(o_ref.shape, o_ref.stride(), dtype=torch.uint8), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # convert ref : f32 -> fp4/fp8 -> f32 + cute.testing.convert(ref_o_f32, ref_narrow_precision) + cute.testing.convert(ref_narrow_precision, ref_o_f32) + + # check output ref + torch.testing.assert_close( + o_fp32_torch, ref_o_f32_torch, atol=tolerance, rtol=1e-05 + ) + + # check lse ref + lse_result = lse_torch.cpu() + torch.testing.assert_close(lse_result, lse_ref, atol=tolerance, rtol=1e-05) + + print("Results verified successfully!") + + def generate_tensors(): + _, q_tensor_workspace, _ = create_and_permute_tensor( + b, s_q, h_k, h_r, d, in_dtype, is_dynamic_layout=True + ) + _, k_tensor_workspace, _ = create_and_permute_tensor( + b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True + ) + _, v_tensor_workspace, _ = create_and_permute_tensor( + b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True, tensor_name="v" + ) + _, o_tensor_workspace, _ = create_and_permute_tensor( + b, s_q, h_k, h_r, d, out_dtype, is_dynamic_layout=True + ) + _, lse_tensor_workspace, _ = create_and_permute_tensor( + b, s_q, h_k, h_r, 1, qk_acc_dtype, is_dynamic_layout=True + ) + + return testing.JitArguments( + q_tensor_workspace, + k_tensor_workspace, + v_tensor_workspace, + o_tensor_workspace, + lse_tensor_workspace, + scale_softmax_log2, + scale_softmax, + scale_output, + ( + window_size_left + if window_size_left is None + else cutlass.Int32(window_size_left) + ), + ( + window_size_right + if window_size_right is None + else cutlass.Int32(window_size_right) + ), + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch + k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch + v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch + o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch + lse_torch_effective = lse_torch.values() if lse_torch.is_nested else lse_torch + one_workspace_bytes = ( + q_torch_effective.numel() * q_torch_effective.element_size() + + k_torch_effective.numel() * k_torch_effective.element_size() + + v_torch_effective.numel() * v_torch_effective.element_size() + + o_torch_effective.numel() * o_torch_effective.element_size() + + lse_torch_effective.numel() * lse_torch_effective.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_fmha, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str): + try: + return [int(x.strip()) for x in s.split(",")] + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description=""" +This example showcases the use of CUTE DSL builders to easily construct fused multi-head attention forward-pass kernels targeting NVIDIA's Hopper architecture. +""" + ) + + parser.add_argument( + "--in_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Input data type", + ) + + parser.add_argument( + "--out_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Output data type", + ) + + parser.add_argument( + "--qk_acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="QK accumulator data type", + ) + + parser.add_argument( + "--pv_acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="PV accumulator data type", + ) + + parser.add_argument( + "--mma_tile_shape_mn", + type=parse_comma_separated_ints, + default=[64, 128], + help="MMA tile shape (M, N)", + ) + + parser.add_argument( + "--is_persistent", + action="store_true", + help="Is persistent", + ) + + parser.add_argument( + "--is_causal", + action="store_true", + help="Whether to use causal mask", + ) + + parser.add_argument( + "--q_shape", + type=parse_comma_separated_ints, + default=[1, 128, 16, 128], + help="Shape of Q (B, S_q, H, D)", + ) + + parser.add_argument( + "--k_shape", + type=parse_comma_separated_ints, + default=[1, 128, 16, 128], + help="Shape of K (B, S_k, H_k, D)", + ) + + parser.add_argument( + "--scale_q", + type=float, + default=1.0, + help="Scaling factors to dequantize Q", + ) + + parser.add_argument( + "--scale_k", + type=float, + default=1.0, + help="Scaling factors to dequantize K", + ) + + parser.add_argument( + "--scale_v", + type=float, + default=1.0, + help="Scaling factors to dequantize V", + ) + + parser.add_argument( + "--inv_scale_o", + type=float, + default=1.0, + help="Scaling factor to quantize O", + ) + + parser.add_argument( + "--scale_softmax", + type=float, + default=1.0, + help="Scaling factor to scale S (i.e. Q*K); if zero, defaults to 1/sqrt(D)", + ) + + parser.add_argument( + "--window_size", + type=parse_comma_separated_ints, + default=(-1, -1), + help="Sliding window size (left, right) for attention masking.", + ) + + parser.add_argument( + "--bottom_right_align", + action="store_true", + help="Whether to use bottom right align, under this settion, the end of q is aligned with the end of k.", + ) + + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of iterations for warmup", + ) + + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations after warmup", + ) + + parser.add_argument( + "--skip_ref_check", + action="store_true", + help="Skip reference check", + ) + + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.q_shape) != 4: + parser.error("--q_shape must contain exactly 4 values") + + if len(args.k_shape) != 4: + parser.error("--k_shape must contain exactly 4 values") + + if len(args.mma_tile_shape_mn) != 2: + parser.error("--mma_tile_shape_mn must contain exactly 2 values") + + run( + args.q_shape, + args.k_shape, + args.in_dtype, + args.out_dtype, + args.qk_acc_dtype, + args.pv_acc_dtype, + args.mma_tile_shape_mn, + args.is_persistent, + args.is_causal, + args.bottom_right_align, + args.scale_q, + args.scale_k, + args.scale_v, + args.inv_scale_o, + args.scale_softmax, + args.window_size, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb b/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb new file mode 100644 index 00000000..3e8f2a97 --- /dev/null +++ b/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb @@ -0,0 +1,599 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import cutlass\n", + "import cutlass.cute as cute\n", + "from cutlass.cute.runtime import from_dlpack" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "# Tutorial: Warp Specialization with Async Pipeline in CuTe DSL\n", + "\n", + "This tutorial explores advanced CUDA programming techniques for implementing efficient producer-consumer \n", + "patterns using asynchronous communication primitives in the CuTe Domain Specific Language (DSL).\n", + "\n", + "## Foundation: Inter-Warp Communication Basics\n", + "\n", + "### Understanding CUDA Warps and Shared Memory\n", + "\n", + "A **warp** is the fundamental execution unit in CUDA, consisting of 32 threads that execute instructions in Single Instruction, \n", + "Multiple Thread (SIMT) fashion on a Streaming Multiprocessor (SM). Understanding warp-level programming is crucial for \n", + "achieving optimal GPU performance.\n", + "\n", + "**Key Concepts:**\n", + "- Warps execute in lockstep, making them ideal for SIMD operations\n", + "- Multiple warps within a thread block (CTA) can cooperate through shared memory\n", + "- Shared memory provides low-latency, high-bandwidth communication between threads\n", + "\n", + "### Shared Memory Architecture\n", + "\n", + "**Shared memory** serves as a programmer-managed cache with several important characteristics:\n", + "\n", + "- **Speed**: ~100x faster than global memory access\n", + "- **Scope**: Accessible by all threads within the same thread block\n", + "- **Organization**: Divided into banks (typically 32) to enable parallel access\n", + "- **Conflicts**: Bank conflicts occur when multiple threads access the same bank simultaneously\n", + "\n", + "### Traditional Synchronous Communication\n", + "\n", + "The conventional approach for inter-warp communication relies on explicit synchronization barriers. The following sequence diagram \n", + "illustrates the typical producer-consumer pattern:\n", + "\n", + "```mermaid\n", + "sequenceDiagram\n", + " participant W0 as Producer Warp\n", + " participant SMEM as Shared Memory\n", + " participant W1 as Consumer Warp\n", + " \n", + " W0->>SMEM: Write data\n", + " critical Synchronization Barrier\n", + " W0-->W1: __syncthreads()\n", + " SMEM->>W1: Read data\n", + " W0-->W1: __syncthreads()\n", + " end\n", + "```\n", + "\n", + "**Limitations of Synchronous Communication:**\n", + "- All warps must wait at synchronization points\n", + "- No opportunity for overlapped computation\n", + "- Reduced overall throughput due to forced serialization" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def synced_producer_consumer(SharedStorage: cutlass.Constexpr, res: cute.Tensor):\n", + " warp_idx = cute.arch.warp_idx()\n", + " warp_idx = cute.arch.make_warp_uniform(warp_idx)\n", + "\n", + " smem = cutlass.utils.SmemAllocator()\n", + " storage = smem.allocate(SharedStorage, 64)\n", + "\n", + " staging_smem = storage.staging_buffer.get_tensor(cute.make_layout(1))\n", + " staging_smem.fill(0)\n", + " cute.arch.sync_threads()\n", + "\n", + " for i in cutlass.range(cute.size(res)):\n", + " if warp_idx == 0:\n", + " staging_smem[0] = i * 1.0\n", + " # mark enter of critical region\n", + " cute.arch.sync_threads()\n", + " if warp_idx == 1:\n", + " res[i] = staging_smem[0]\n", + " # mark exit of critical region\n", + " cute.arch.sync_threads()\n", + "\n", + "\n", + "@cute.jit\n", + "def run_synced_producer_consumer(res: cute.Tensor):\n", + " @cute.struct\n", + " class SharedStorage:\n", + " staging_buffer: cute.struct.Align[\n", + " cute.struct.MemRange[cutlass.Float32, 1], 1024\n", + " ]\n", + "\n", + " synced_producer_consumer(SharedStorage, res).launch(\n", + " grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()\n", + " )\n", + "\n", + "\n", + "res = torch.zeros((8,), device=\"cuda\")\n", + "run_synced_producer_consumer(from_dlpack(res))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='cuda:0')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "\n", + "\n", + "## Asynchronous Communication: Breaking the Synchronization Bottleneck\n", + "\n", + "### The Problem with Synchronous Patterns\n", + "\n", + "The previous example demonstrates traditional synchronized communication between warps. While functional, this approach \n", + "has significant performance limitations:\n", + "\n", + "**Critical Section Analysis:**\n", + "- **First `__syncthreads()`**: Ensures data is written and ready for consumption\n", + "- **Second `__syncthreads()`**: Guarantees data has been consumed and memory can be safely overwritten\n", + "\n", + "**Performance Impact:**\n", + "- All warps are forced into lockstep execution\n", + "- No computational overlap between producer and consumer operations\n", + "- Wasted cycles as warps wait at synchronization barriers\n", + "\n", + "### Hopper Architecture: Enabling Asynchronous Primitives\n", + "\n", + "Starting with the Hopper architecture, CUDA introduced sophisticated asynchronous communication primitives that enable \n", + "**warp specialization**—allowing different warps to perform distinct, specialized roles while maintaining loose coupling.\n", + "\n", + "**Key Benefits:**\n", + "- **Overlapped Execution**: Producer and consumer warps can perform computations concurrently\n", + "- **Reduced Latency**: Eliminates unnecessary synchronization stalls\n", + "- **Better Resource Utilization**: Maximizes SM occupancy and throughput\n", + "\n", + "### Async Pipeline Communication Pattern\n", + "\n", + "The async pipeline abstraction provides a elegant solution for producer-consumer communication without rigid synchronization constraints:\n", + "\n", + "```mermaid\n", + "sequenceDiagram\n", + " participant W0 as Producer Warp\n", + " participant Pipeline as Async Pipeline\n", + " participant SMEM as Shared Memory \n", + " participant W1 as Consumer Warp\n", + " \n", + " W0->>Pipeline: Acquire (request write slot)\n", + " activate W1\n", + " Pipeline-->>W0: Grant access\n", + " deactivate W1\n", + " \n", + " W1->>Pipeline: Wait (for data availability)\n", + " activate Pipeline\n", + " \n", + " W0->>SMEM: Write data\n", + " W0->>Pipeline: Commit (signal data ready)\n", + " \n", + " Pipeline-->>W1: Data available\n", + " deactivate Pipeline\n", + " \n", + " activate W0\n", + " SMEM->>W1: Read data\n", + " deactivate W0\n", + " W1->>Pipeline: Release (mark slot available)\n", + "```\n", + "\n", + "**Async Pipeline Advantages:**\n", + "- **Non-blocking Operations**: Warps can perform other work while waiting\n", + "- **Fine-grained Control**: Explicit control over data readiness and consumption\n", + "- **Scalable**: Supports multiple producer-consumer pairs efficiently" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Async Pipeline API Reference\n", + "\n", + "The `PipelineAsync` abstraction in CuTe DSL provides a comprehensive set of primitives for implementing efficient producer-consumer patterns:\n", + "\n", + "#### Producer Operations\n", + "- **`PipelineProducer.acquire()`**: Blocks until a write slot becomes available (released by consumer)\n", + " - Returns with a handle pointing to a available slot immediately if there is\n", + " - Enables backpressure control to prevent buffer overflow\n", + " - **`PipelineProducer.acquire_and_advance()`** additionally moves the producer's write index to the next buffer slot\n", + "\n", + "- **`PipelineProducer.commit(PipelineProducer.ImmutableProducerHandle)`** / **`PipelineProducer.ImmutableProducerHandle.commit()`**: Signals that data has been written to the handle-pointed slot and is ready for consumption\n", + " - Triggers waiting consumers\n", + " - Maintains data consistency guarantees\n", + " - If no assigned handle, **`PipelineConsumerHandle.release()`** tracks its internal maintained handle (pointed to the last one it acquires)\n", + "\n", + "#### Consumer Operations \n", + "- **`PipelineConsumer.wait()`**: Blocks until data becomes available for reading\n", + " - Returns with a handle pointing to a committed slot when producer commits new data\n", + " - Supports timeout and polling variants\n", + " - **`PipelineConsumer.wait_and_advance()`** additionally moves the consumer's read index to the next buffer slot\n", + "\n", + "- **`PipelineConsumerHandle.release(PipelineConsumer.ImmutableConsumerHandle)`** / **`PipelineConsumer.ImmutableConsumerHandle.release()`**: Marks data as consumed and the handle-pointed slot as consumed and available for reuse\n", + " - Enables producers to acquire released slots\n", + " - Critical for preventing deadlock in circular buffers\n", + " - If no assigned handle, **`PipelineConsumerHandle.release()`** tracks its internal maintained handle (pointed to the last one it waits for)\n", + "\n", + "#### Disclaimer\n", + "\n", + "The `pipeline` APIs provided abstractions for developers to manage synchornization between warps, thread-blocks, etc. It doesn't provide deadlock-free guarantee. It's still developer's responsibility to write correct code to avoid deadlock.\n", + "\n", + "#### Performance Characteristics\n", + "\n", + "**Computational Overlap**: This asynchronous communication pattern enables limited but significant computational overlap:\n", + "- **Producer**: Can perform preprocessing, data transformation, or prefetching while consumer processes previous data\n", + "- **Consumer**: Can execute post-processing, result computation, or output operations while producer prepares next data\n", + "\n", + "**Memory Efficiency**: Explicit slot management ensures optimal memory utilization without unnecessary copying or buffering." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def async_pipeline_kernel(res: cute.Tensor):\n", + " warp_idx = cute.arch.warp_idx()\n", + " warp_idx = cute.arch.make_warp_uniform(warp_idx)\n", + "\n", + " @cute.struct\n", + " class SharedStorage:\n", + " tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]\n", + " staging_buffer: cute.struct.Align[\n", + " cute.struct.MemRange[cutlass.Float32, 1], 1024\n", + " ]\n", + "\n", + " smem = cutlass.utils.SmemAllocator()\n", + " storage = smem.allocate(SharedStorage, 64)\n", + "\n", + " # Warp 0\n", + " producer_group = cutlass.pipeline.CooperativeGroup(\n", + " cutlass.pipeline.Agent.Thread, 32\n", + " )\n", + " # Warp 1\n", + " consumer_group = cutlass.pipeline.CooperativeGroup(\n", + " cutlass.pipeline.Agent.Thread, 32\n", + " )\n", + "\n", + " pipeline = cutlass.pipeline.PipelineAsync.create(\n", + " num_stages=1,\n", + " producer_group=producer_group,\n", + " consumer_group=consumer_group,\n", + " barrier_storage=storage.tma_mbar_ptr.data_ptr(),\n", + " )\n", + "\n", + " staging_smem = storage.staging_buffer.get_tensor(cute.make_layout(1))\n", + " staging_smem.fill(0)\n", + " cute.arch.sync_threads()\n", + "\n", + " producer, consumer = pipeline.make_participants()\n", + "\n", + " # Producer warp\n", + " if warp_idx == 0:\n", + " for i in cutlass.range(cute.size(res)):\n", + " # Producer: Wait for data buffer is available\n", + " handle = producer.acquire_and_advance()\n", + " # Producer: Write data to shared memory\n", + " staging_smem[handle.index] = 1.0 * i\n", + " # Producer: Signal data is ready for consumption\n", + " handle.commit()\n", + " producer.tail()\n", + "\n", + " # Consumer warp\n", + " if warp_idx == 1:\n", + " for i in cutlass.range(cute.size(res)):\n", + " # Consumer: Wait for producer to signal when data is available for use\n", + " handle = consumer.wait_and_advance()\n", + " # Conumer: consumes data\n", + " res[i] = staging_smem[handle.index]\n", + " # Conumer: Signal data buffer is ready for write\n", + " handle.release()\n", + "\n", + "\n", + "@cute.jit\n", + "def async_pipeline(res: cute.Tensor):\n", + " # Launch kernel with two warps: producer and consumer\n", + " async_pipeline_kernel(res).launch(grid=(1, 1, 1), block=(64, 1, 1))\n", + "\n", + "\n", + "res = torch.zeros((8,), device=\"cuda\")\n", + "async_pipeline(from_dlpack(res))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='cuda:0')" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "## Advanced Pattern: Staged Async Pipeline with Circular Buffering\n", + "\n", + "### Limitations of Single-Stage Pipelines\n", + "\n", + "While async communication provides significant improvements over synchronous patterns, single-stage pipelines \n", + "still exhibit serialization bottlenecks:\n", + "\n", + "**Dependency Chain Analysis:**\n", + "```mermaid\n", + "sequenceDiagram\n", + " participant W0 as Producer\n", + " participant Pipeline as Pipeline\n", + " participant W1 as Consumer\n", + " \n", + " W0->>Pipeline: Acquire\n", + " Note over W0,W1: Producer waits here\n", + " W1->>Pipeline: Release\n", + " Pipeline-->>W0: Granted\n", + "```\n", + "\n", + "**Performance Bottleneck**: The producer must wait for the consumer to complete processing and release the buffer \n", + "before acquiring the next write slot. This creates a serialization point that limits overall throughput.\n", + "\n", + "### Multi-Stage Pipeline Architecture\n", + "\n", + "The **staged async pipeline** implements a circular buffer managed by an array of synchronization barriers, \n", + "enabling much higher degrees of parallelism:\n", + "\n", + "#### Core Concepts\n", + "\n", + "**Circular Buffer Management:**\n", + "- **Multiple Stages**: Support for N concurrent buffer slots (typically 2-8 stages)\n", + "- **Independent Indexing**: Producer and consumer maintain separate advancement indices\n", + "- **Barrier Array**: Each stage has an associated memory barrier for fine-grained synchronization\n", + "\n", + "#### Enhanced API Operations\n", + "\n", + "- **`PipelineProducer.advance()`**: Moves the producer's write index to the next buffer slot\n", + " - Enables round-robin buffer allocation\n", + " - Allows producer to continue without waiting for all previous data to be consumed\n", + " - Can be conducted implicitly when calling **`PipelineProducer.require_and_advance()`**\n", + "\n", + "- **`PipelineConsumer.advance()`**: Moves the consumer's read index to the next buffer slot\n", + " - Maintains proper ordering of data consumption\n", + " - Signals availability of processed slots\n", + " - Can be conducted implicitly when calling **`PipelineConsumer.wait_and_advance()`**\n", + "\n", + "- **`PipelineProducer.ImmutableResourceHandle.index`** / **`PipelineConsumer.ImmutableResourceHandle.index`**: Returns pointed buffer slot index\n", + " - Used for addressing specific staging buffer locations\n", + " - Enables direct slot-based data access\n", + "\n", + "### Circular Buffer State Visualization\n", + "\n", + "```\n", + "Legend:\n", + " W: Currently being written (producer active)\n", + " D: Data ready for consumption \n", + " R: Currently being read (consumer active)\n", + " X: Empty slot available for writing\n", + " \n", + " Advance Direction\n", + " <-------------------\n", + "\n", + " Producer Consumer\n", + " | ^\n", + " V |\n", + " +-----------------+\n", + " --|X|X|W|D|D|D|D|R|X|<-.\n", + " / +-----------------+ \\\n", + " | |\n", + " `------------------------' \n", + "```\n", + "\n", + "**Key Advantages:**\n", + "- **Increased Throughput**: Producer can stay ahead of consumer by multiple stages\n", + "- **Latency Hiding**: Consumer processing latency is hidden by buffered data\n", + "- **Better Resource Utilization**: Both warps can maintain high activity levels\n", + "- **Scalable Design**: Buffer depth can be tuned based on workload characteristics\n", + "\n", + "The following implementation demonstrates efficient multi-stage pipeline communication with proper circular buffer management:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def async_pipeline_staged_kernel(\n", + " SharedStorage: cutlass.Constexpr, res: cute.Tensor, staging: cute.Tensor\n", + "):\n", + " stages = cute.size(staging)\n", + "\n", + " warp_idx = cute.arch.warp_idx()\n", + " warp_idx = cute.arch.make_warp_uniform(warp_idx)\n", + "\n", + " smem = cutlass.utils.SmemAllocator()\n", + " storage = smem.allocate(SharedStorage, 64)\n", + "\n", + " # Warp 0\n", + " producer_group = cutlass.pipeline.CooperativeGroup(\n", + " cutlass.pipeline.Agent.Thread, 32\n", + " )\n", + " # Warp 1\n", + " consumer_group = cutlass.pipeline.CooperativeGroup(\n", + " cutlass.pipeline.Agent.Thread, 32\n", + " )\n", + "\n", + " pipeline = cutlass.pipeline.PipelineAsync.create(\n", + " num_stages=stages,\n", + " producer_group=producer_group,\n", + " consumer_group=consumer_group,\n", + " barrier_storage=storage.tma_mbar_ptr.data_ptr(),\n", + " )\n", + "\n", + " staging_smem = storage.staging_buffer.get_tensor(staging.layout)\n", + " staging_smem.fill(0)\n", + " cute.arch.sync_threads()\n", + "\n", + " producer, consumer = pipeline.make_participants()\n", + "\n", + " # Producer warp\n", + " if warp_idx == 0:\n", + " for i in cutlass.range(cute.size(res)):\n", + " handle = producer.acquire_and_advance()\n", + " staging_smem[handle.index] = 1.0 * i\n", + " handle.commit() # or producer.commit(handle)\n", + "\n", + " # prevents CTA0 from retiring until it receives all expected arrives.\n", + " producer.tail()\n", + "\n", + " # Consumer warp\n", + " if warp_idx == 1:\n", + " for i in cutlass.range(cute.size(res)):\n", + " handle = consumer.wait_and_advance()\n", + " res[i] = staging_smem[handle.index]\n", + " handle.release() # or consumer.release(handle)\n", + "\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " if tidx == 0:\n", + " staging.store(staging_smem.load())\n", + "\n", + "\n", + "@cute.jit\n", + "def async_pipeline_staged(res: cute.Tensor, staging: cute.Tensor):\n", + " stages = cute.size(staging)\n", + "\n", + " @cute.struct\n", + " class SharedStorage:\n", + " tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, stages * 2]\n", + " staging_buffer: cute.struct.Align[\n", + " cute.struct.MemRange[cutlass.Float32, stages], 1024\n", + " ]\n", + "\n", + " async_pipeline_staged_kernel(SharedStorage, res, staging).launch(\n", + " grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()\n", + " )\n", + "\n", + "\n", + "res = torch.zeros((8,), device=\"cuda\")\n", + "staging = torch.zeros((5,), device=\"cuda\")\n", + "async_pipeline_staged(from_dlpack(res), from_dlpack(staging))\n", + "torch.cuda.synchronize()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='cuda:0'),\n", + " tensor([5., 6., 7., 3., 4.], device='cuda:0'))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res, staging" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Try Acquire/Wait\n", + "\n", + "In some circumstances, developers may want to just check status of pipeline state without blocking. This could benefit some cases that we have independent instructions to hide latency of checking pipeline state. We provided `try_aquire` or `try_wait` which are non-blocking APIs. " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb b/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb new file mode 100644 index 00000000..17b38429 --- /dev/null +++ b/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb @@ -0,0 +1,460 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import cutlass\n", + "import cutlass.cute as cute\n", + "import cutlass.cute.testing as testing\n", + "import cutlass.torch as cutlass_torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The Usage of Benchmark and Autotune Utilities in CuTe DSL\n", + "\n", + "CuTe DSL provides autotune and benchmark utilities to help users evaluate and optimize kernel performance. This notebook demonstrates how to use these tools.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Autotune\n", + "\n", + "We provides two kinds of autotune utilities for users: `autotune.jit` decorator and the `tune` function. The former is used as a decorator used on top of `@cute.jit` while the latter is used as an individual function.\n", + "\n", + "#### @autotune.jit\n", + "\n", + "We take the `elementwise_add_kernel` as an example. After writing the jit host function and kernel, we could add the `@autotune_jit` decorator on top of the jit host function to enable autotune. \n", + "```python\n", + "@testing.autotune_jit(\n", + " params_dict={\"copy_bits\": [64, 128]},\n", + " update_on_change=[\"M\", \"N\"],\n", + " warmup_iterations=100,\n", + " iterations=100,\n", + ")\n", + "```\n", + "\n", + "The `autotune_jit` decorator provides several parameters to control the autotuning process:\n", + "\n", + "- params_dict: A dictionary containing the parameters to be tuned and their possible values\n", + "- update_on_change: A list of argument names that trigger re-tuning when their values change\n", + "- warmup_iterations: Number of warmup iterations before timing\n", + "- iterations: Number of iterations for timing each parameter combination\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def elementwise_add_kernel(\n", + " gA: cute.Tensor,\n", + " gB: cute.Tensor,\n", + " gC: cute.Tensor,\n", + " cC: cute.Tensor, # coordinate tensor\n", + " shape: cute.Shape,\n", + " thr_layout: cute.Layout,\n", + " val_layout: cute.Layout,\n", + "):\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bidx, _, _ = cute.arch.block_idx()\n", + "\n", + " # slice for CTAs\n", + " # logical id -> address\n", + " blk_coord = ((None, None), bidx)\n", + " blkA = gA[blk_coord] # (TileM,TileN)\n", + " blkB = gB[blk_coord] # (TileM,TileN)\n", + " blkC = gC[blk_coord] # (TileM,TileN)\n", + " blkCrd = cC[blk_coord] # (TileM, TileN)\n", + "\n", + " # # declare the atoms which will be used later for memory copy\n", + " copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)\n", + " copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)\n", + "\n", + " tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)\n", + " tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)\n", + " tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)\n", + "\n", + " thr_copy_A = tiled_copy_A.get_slice(tidx)\n", + " thr_copy_B = tiled_copy_B.get_slice(tidx)\n", + " thr_copy_C = tiled_copy_C.get_slice(tidx)\n", + "\n", + " thrA = thr_copy_A.partition_S(blkA)\n", + " thrB = thr_copy_B.partition_S(blkB)\n", + " thrC = thr_copy_C.partition_S(blkC)\n", + "\n", + " # allocate fragments for gmem->rmem\n", + " frgA = cute.make_fragment_like(thrA)\n", + " frgB = cute.make_fragment_like(thrB)\n", + " frgC = cute.make_fragment_like(thrC)\n", + "\n", + " thrCrd = thr_copy_C.partition_S(blkCrd)\n", + " frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)\n", + "\n", + " for i in range(0, cute.size(frgPred), 1):\n", + " val = cute.elem_less(thrCrd[i], shape)\n", + " frgPred[i] = val\n", + "\n", + " ##########################################################\n", + " # Move data to reg address space\n", + " ##########################################################\n", + "\n", + " cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)\n", + " cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)\n", + "\n", + " # Load data before use. The compiler will optimize the copy and load\n", + " # operations to convert some memory ld/st into register uses.\n", + " result = frgA.load() + frgB.load()\n", + "\n", + " # Save the results back to registers. Here we reuse b's registers.\n", + " frgC.store(result)\n", + "\n", + " # Copy the results back to c\n", + " cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)\n", + "\n", + "\n", + "@testing.autotune_jit(\n", + " params_dict={\"copy_bits\": [64, 128]},\n", + " update_on_change=[\"M\", \"N\"],\n", + " warmup_iterations=100,\n", + " iterations=100,\n", + ")\n", + "@cute.jit\n", + "def elementwise_add_autotune(mA, mB, mC, M, N, copy_bits: cutlass.Constexpr = 128):\n", + " dtype = mA.element_type\n", + " vector_size = copy_bits // dtype.width\n", + "\n", + " thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))\n", + " val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))\n", + " tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n", + "\n", + " gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN))\n", + " gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN))\n", + " gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN))\n", + " idC = cute.make_identity_tensor(mC.shape)\n", + " cC = cute.zipped_divide(idC, tiler=tiler_mn)\n", + "\n", + " elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(\n", + " grid=[cute.size(gC, mode=[1]), 1, 1],\n", + " block=[cute.size(tv_layout, mode=[0]), 1, 1],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When we run the jit funciton `elementwise_add_autotune`, the CuTe DSL will help us tune the kernels by looping the specified configs and run the kernel with the best config.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "\n", + "M, N = 1024, 1024\n", + "dtype = cutlass.Float32\n", + "skip_ref_check = False\n", + "\n", + "print(f\"\\nRunning Elementwise Add test with:\")\n", + "print(f\"Tensor dimensions: [{M}, {N}]\")\n", + "print(f\"Input and Output Data type: {dtype}\")\n", + "\n", + "torch_dtype = cutlass_torch.dtype(dtype)\n", + "\n", + "a = torch.randn(M, N, device=torch.device(\"cuda\"), dtype=torch_dtype)\n", + "b = torch.randn(M, N, device=torch.device(\"cuda\"), dtype=torch_dtype)\n", + "\n", + "c = torch.zeros_like(a)\n", + "\n", + "print(f\"Input tensor shapes:\")\n", + "print(f\"a: {a.shape}, dtype: {a.dtype}\")\n", + "print(f\"b: {b.shape}, dtype: {b.dtype}\")\n", + "print(f\"c: {c.shape}, dtype: {c.dtype}\\n\")\n", + "\n", + "elementwise_add_autotune(a, b, c, M, N)\n", + "\n", + "if not skip_ref_check:\n", + " print(\"Verifying results for autotuned function ...\")\n", + " torch.testing.assert_close(a + b, c)\n", + " print(\"Results verified successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The output is as follows:\n", + "\n", + "```\n", + "Running Elementwise Add test with:\n", + "Tensor dimensions: [1024, 1024]\n", + "Input and Output Data type: Float32\n", + "Input tensor shapes:\n", + "a: torch.Size([1024, 1024]), dtype: torch.float32\n", + "b: torch.Size([1024, 1024]), dtype: torch.float32\n", + "c: torch.Size([1024, 1024]), dtype: torch.float32\n", + "Verifying results for autotuned function ...\n", + "Results verified successfully!\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "To monitor the autotuning process in detail, you can enable logging by setting the environment variable `CUTE_DSL_LOG_AUTOTUNE`. \n", + "```shell\n", + "export CUTE_DSL_LOG_AUTOTUNE=1\n", + "```\n", + "This will display comprehensive information including:\n", + "- Each configuration being evaluated and its corresponding execution time\n", + "- The optimal configuration that was selected\n", + "- Total time spent on tuning\n", + "- Cache hit/miss statistics\n", + "\n", + "\n", + "Below is a sample output showing the autotuning process with different configurations:\n", + "```python\n", + "2025-07-23 06:17:03,978 - cutlass.cute.testing_Autotune - INFO - Tuning configuration: {'copy_bits': 64}\n", + "2025-07-23 06:17:04,519 - cutlass.cute.testing_Autotune - INFO - Execution time: 0.010857919985428453 us\n", + "2025-07-23 06:17:04,519 - cutlass.cute.testing_Autotune - INFO - Tuning configuration: {'copy_bits': 128}\n", + "2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Execution time: 0.011117440033704042 us\n", + "2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Best configuration: {'copy_bits': 64}, execution time: 0.010857919985428453 us\n", + "2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Total tuning time: 0.7053244113922119 s\n", + "...\n", + "2025-07-23 06:17:04,700 - cutlass.cute.testing_Autotune - INFO - Using cached best configuration: {'copy_bits': 64}\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### tune\n", + "\n", + "We also provide a `tune` funtion. The interface of the `tune` function is as follows:\n", + "\n", + "```python\n", + "def tune(\n", + " func: Callable[[Any], Callable[[], Any]],\n", + " params_dict: Dict[str, List[Any]] = None,\n", + " kernel_arguments: JitArguments = JitArguments(),\n", + " warmup_iterations=10,\n", + " iterations=100,\n", + " stream: Optional[cuda_driver.CUstream] = None,\n", + ") -> Dict[str, Any]:\n", + "```\n", + "\n", + "The `tune` function takes the following parameters:\n", + "\n", + "- func: A callable that takes configuration parameters and returns a kernel function\n", + "- params_dict: Dictionary mapping parameter names to lists of possible values to tune\n", + "- kernel_arguments: Arguments to pass to the kernel for tuning\n", + "- warmup_iterations: Number of warmup iterations before timing (default: 10)\n", + "- iterations: Number of timing iterations per configuration (default: 100)\n", + "- stream: Optional CUDA stream to use for execution. defaults to default CUDA stream. The stream parameter must match the stream passed to the kernel, mismatched streams will result in an error.\n", + "\n", + "It returns a dictionary containing the best kernel configuration found.\n", + "\n", + "\n", + "Here is an example to use the `tune` function:\n", + "\n", + "1. First remove the `@testing.autotune_jit` decorator from the `elementwise_add_autotune` function:\n", + " ```python\n", + " @testing.autotune_jit(\n", + " params_dict={\"copy_bits\": [64, 128]},\n", + " update_on_change=[\"M\", \"N\"], \n", + " warmup_iterations=100,\n", + " iterations=100,\n", + " )\n", + " ```\n", + "\n", + " 2. Define a `tune_func` that:\n", + " - Takes input tensors (a, b, c), dimensions (M, N) and tuning parameter copy_bits\n", + " - Compiles the `elementwise_add_autotune` function using `cute.compile()`\n", + " - Returns a lambda function that executes the compiled kernel\n", + "\n", + " 3. Pass `tune_func` to `testing.tune` function along with:\n", + " - Parameter space to explore (copy_bits values)\n", + " - Kernel arguments wrapped in JitArguments\n", + " - The `tune` function will find optimal parameters automatically\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "def tune_func(a, b, c, M, N, copy_bits=128):\n", + " compiled_func = cute.compile(elementwise_add_autotune, a, b, c, M, N, copy_bits=128)\n", + " return lambda: compiled_func(a, b, c, M, N)\n", + "\n", + "params = testing.tune(\n", + " tune_func,\n", + " params_dict={\"copy_bits\": [64, 128]},\n", + " kernel_arguments=testing.JitArguments(a, b, c, M, N),\n", + ")\n", + "print(f\"The best kernel configs found: {params}\")\n", + "\n", + "# run the kernel with the best config\n", + "compiled_func = cute.compile(elementwise_add_autotune, a, b, c, M, N, **params)\n", + "compiled_func(a, b, c, M, N)\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### benchmark\n", + "\n", + "In CuTe DSL, the benchmark utility can be used to measure kernel execution time. The interface of benchmark routine is as follows:\n", + "\n", + "```python\n", + "def benchmark(\n", + " callable: Callable,\n", + " *,\n", + " warmup_iterations: int = 10,\n", + " iterations: int = 100,\n", + " stream: Optional[cuda_driver.CUstream] = None,\n", + " kernel_arguments: Optional[JitArguments] = None,\n", + " workspace_generator: Optional[Callable[[], JitArguments]] = None,\n", + " workspace_count: int = 1,\n", + " use_cuda_graphs: bool = False,\n", + ") -> float:\n", + "```\n", + "\n", + "The benchmark utility exposes several key configuration parameters to control profiling behavior:\n", + "\n", + "- callable: The function to be benchmarked\n", + "- warmup_iterations: Controls the number of initial warmup iterations before measurement begins (default: 10)\n", + "- iterations: Specifies how many iterations to profile for performance measurement (default: 100)\n", + "- stream: Designates which CUDA stream to execute the kernel on (default: default stream) \n", + "- use_cuda_graphs: Whether enables CUDA graph for the callable function to minimize kernel launch overhead (default: False)\n", + "- workspace_generator: Provides a function that generates fresh kernel arguments each iteration to avoid caching effects\n", + "- workspace_count: Determines how many different workspaces to cycle through during profiling (default: 1)\n", + "\n", + "When benchmarking, there are several key parameters that can be configured:\n", + "\n", + "1. Core parameters:\n", + " - The function to profile (callable)\n", + " - Number of warmup iterations before measurement\n", + " - Number of profiling iterations for measurement\n", + "\n", + "2. Stream configuration:\n", + " - For kernels running in non-default streams, the stream must be specified\n", + " - The stream parameter must match the stream passed to the kernel, mismatched streams will result in an error\n", + "\n", + "3. Cache effects mitigation:\n", + " - To prevent L2 cache effects from skewing results, multiple workspaces can be cycled through\n", + " - This is controlled via workspace_count and workspace_generator parameters\n", + " - Each workspace provides fresh kernel arguments\n", + "\n", + "4. CUDA Graph support:\n", + " - Enables measuring kernel execution time without host overhead\n", + " - Requires the callable to be decorated with @cute.jit\n", + " - Must use a non-default CUDA stream when using graphs\n", + "\n", + "This function will return the execution time of the callable in microseconds. As GPU frequency can vary dynamically, we could fix the SM and memory frequencies to get more stable and reproducible benchmark results. This can be done by setting the GPU clocks using nvidia-smi before running the benchmark. In the next, let's use the benchmark function to get the execution time of the above elementwise_add kernel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "def generate_kernel_arguments():\n", + " a = torch.randn(\n", + " M, N, device=torch.device(\"cuda\"), dtype=torch_dtype\n", + " )\n", + " b = torch.randn(\n", + " M, N, device=torch.device(\"cuda\"), dtype=torch_dtype\n", + " )\n", + "\n", + " c = torch.zeros_like(a)\n", + "\n", + " return testing.JitArguments(a, b, c, M, N)\n", + "\n", + "avg_time_us = testing.benchmark(\n", + " elementwise_add_autotune,\n", + " workspace_generator=generate_kernel_arguments,\n", + " workspace_count=10,\n", + " warmup_iterations=10,\n", + " iterations=100,\n", + ")\n", + "\n", + "# Print execution results\n", + "print(\n", + " f\"Kernel execution time for cute.jit kernel with M={M}, N={N}: {avg_time_us / 1e3:.4f} ms\"\n", + ")\n", + "print(\n", + " f\"Achieved memory throughput for M={M}, N={N}: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After running the code, we will get output similar to the following:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "Kernel execution time for cute.jit kernel with M=1024, N=1024: 0.0403 ms\n", + "Achieved memory throughput for M=1024, N=1024: 312.37 GB/s\n", + "```" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/python/CuTeDSL/notebooks/composed_layout.ipynb b/examples/python/CuTeDSL/notebooks/composed_layout.ipynb new file mode 100644 index 00000000..8eb72219 --- /dev/null +++ b/examples/python/CuTeDSL/notebooks/composed_layout.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0c7cf795", + "metadata": {}, + "source": [ + "# Composed Layout in CuTe\n", + "\n", + "A **Composed Layout** is a powerful abstraction in CuTe that enables complex data transformations through \n", + "the composition of layouts and transformations. It provides a flexible way to manipulate memory layouts \n", + "and coordinate systems.\n", + "\n", + "## Components\n", + "\n", + "A Composed Layout consists of three key components:\n", + "\n", + "1. **Inner Layout/Transformation** (`inner`):\n", + " - Can be a layout, swizzle, or custom transformation function\n", + " - Applies the final transformation to the coordinates\n", + " - Supports arbitrary coordinate manipulations\n", + "\n", + "2. **Offset** (`offset`):\n", + " - Typically represented as an integer tuple\n", + " - Adds a constant displacement to coordinates\n", + " - Enables fine-grained control over data positioning\n", + "\n", + "3. **Outer Layout** (`outer`):\n", + " - The layout visible to the user\n", + " - Defines the initial coordinate transformation\n", + " - Determines the shape and organization of the data structure\n", + "\n", + "## Mathematical Representation\n", + "\n", + "The mathematical composition of these components is defined as:\n", + "\n", + "$\n", + "R(c) := (inner \\circ offset \\circ outer)(c) := inner(offset + outer(c))\n", + "$\n", + "\n", + "Where:\n", + "- $c$ represents the input coordinates\n", + "- $\\circ$ denotes function composition\n", + "- The transformation is applied from right to left\n", + "\n", + "## Usage in Python\n", + "\n", + "To create a Composed Layout in Python, use the `make_composed_layout` function:\n", + "\n", + "```python\n", + "layout = cute.make_composed_layout(inner, offset, outer)\n", + "```\n", + "\n", + "## Key Benefits\n", + "\n", + "1. **Flexibility**: Supports complex transformations that direct composition cannot handle\n", + "2. **Modularity**: Separates different aspects of the transformation\n", + "3. **Performance**: Enables optimized memory access patterns for GPU computations\n", + "4. **Compatibility**: Works with various types of transformations and layouts" + ] + }, + { + "cell_type": "markdown", + "id": "24448f7d", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "source": [ + "## Custom Transformation Example\n", + "\n", + "This example demonstrates how to create a Composed Layout with a custom transformation function. We'll create a simple transformation that:\n", + "\n", + "1. Takes a 2D coordinate input `(x, y)`\n", + "2. Increments the y-coordinate by 1\n", + "3. Combines this with an offset and identity layout\n", + "\n", + "The example shows how to:\n", + "- Define a custom transformation function\n", + "- Create a composed layout with the transformation\n", + "- Apply the layout to coordinates\n", + "- Print the results for verification" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "184f30e6", + "metadata": {}, + "outputs": [], + "source": [ + "import cutlass\n", + "import cutlass.cute as cute\n", + "from cutlass.cute.runtime import from_dlpack, make_ptr\n", + "\n", + "\n", + "@cute.jit\n", + "def customized_layout():\n", + " def inner(c):\n", + " x, y = c\n", + " return x, y + 1\n", + "\n", + " layout = cute.make_composed_layout(\n", + " inner, (1, 0), cute.make_identity_layout(shape=(8, 4))\n", + " )\n", + " print(layout)\n", + " cute.printf(layout(0))\n", + "\n", + "\n", + "customized_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "c897187f", + "metadata": {}, + "source": [ + "## Gather/Scatter Operations with Composed Layout\n", + "\n", + "Gather and Scatter operations are fundamental data access patterns in parallel computing and GPU programming. In CuTe, we can implement these operations elegantly using Composed Layout.\n", + "\n", + "### Gather Operation\n", + "A gather operation collects elements from a source array using an index array (also called an indirection array). It's defined as:\n", + "```python\n", + "output[i] = source[index[i]]\n", + "```\n", + "\n", + "#### Components in CuTe Implementation:\n", + "1. **Offset Tensor**: Contains the indices for gathering (`offset_tensor`)\n", + "2. **Data Pointer**: Points to the source data array (`data_ptr`)\n", + "3. **Shape**: Defines the shape of logic tensor viewed by user (`shape`)\n", + "\n", + "### How it Works\n", + "1. The inner transformation function reads from the offset tensor:\n", + " ```python\n", + " def inner(c):\n", + " return offset_tensor[c] # Returns the gather index\n", + " ```\n", + "2. The composed layout maps input coordinates through the offset tensor:\n", + " ```python\n", + " gather_layout = cute.make_composed_layout(inner, 0, cute.make_layout(shape))\n", + " ```\n", + "3. This creates an indirect access pattern where:\n", + " - Input coordinate `i` → `offset_tensor[i]` → `data_ptr[offset_tensor[i]]`\n", + "\n", + "4. notably, layout operations like slice, partition can still be applied on `outer` layout\n", + "\n", + "### Use Cases\n", + "- **Sparse Operations**: Accessing non-contiguous memory efficiently\n", + "- **Graph Processing**: Following edge connections in graph algorithms\n", + "- **Feature Embedding**: Looking up embeddings for discrete tokens\n", + "- **Irregular Data Access**: Any pattern requiring indirect memory access\n", + "\n", + "### Example Output Interpretation\n", + "The example code prints pairs of numbers `i -> j` where:\n", + "- `i` is the output index\n", + "- `j` is the gathered source index from `offset_tensor`\n", + "\n", + "This demonstrates how the composed layout transforms coordinates for indirect memory access.\n", + "\n", + "Note: Scatter operations (writing to indirect locations) can be implemented similarly by reversing the data flow direction.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d68f9476", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "\n", + "@cute.jit\n", + "def gather_tensor(\n", + " offset_tensor: cute.Tensor, data_ptr: cute.Pointer, shape: cute.Shape\n", + "):\n", + " def inner(c):\n", + " return offset_tensor[c]\n", + "\n", + " gather_layout = cute.make_composed_layout(inner, 0, cute.make_layout(shape))\n", + " for i in cutlass.range_constexpr(cute.size(shape)):\n", + " cute.printf(\"%d -> %d\", i, gather_layout(i))\n", + "\n", + " # TODO: support in future\n", + " # gather_tensor = cute.make_tensor(data_ptr, gather_layout)\n", + " # cute.printf(gather_tensor[0])\n", + "\n", + "\n", + "shape = (16,)\n", + "offset_tensor = torch.randint(0, 256, shape, dtype=torch.int32)\n", + "data_tensor = torch.arange(0, 256, dtype=torch.int32)\n", + "\n", + "\n", + "gather_tensor(\n", + " from_dlpack(offset_tensor),\n", + " make_ptr(cutlass.Int32, data_tensor.data_ptr(), cute.AddressSpace.generic),\n", + " shape,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv3_12", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb b/examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb index dc7c17cf..42385d20 100644 --- a/examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb +++ b/examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb @@ -26,10 +26,11 @@ "source": [ "# import torch for CUDA graphs\n", "import torch\n", - "import cutlass\n", "import cutlass.cute as cute\n", + "\n", "# import CUstream type from the cuda driver bindings\n", "from cuda.bindings.driver import CUstream\n", + "\n", "# import the current_stream function from torch\n", "from torch.cuda import current_stream" ] @@ -61,13 +62,15 @@ " \"\"\"\n", " cute.printf(\"Hello world\")\n", "\n", + "\n", "@cute.jit\n", - "def hello_world(stream : CUstream):\n", + "def hello_world(stream: CUstream):\n", " \"\"\"\n", " Host function that launches our (1,1,1), (1,1,1) grid in stream\n", " \"\"\"\n", " hello_world_kernel().launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream)\n", "\n", + "\n", "# Grab a stream from PyTorch, this will also initialize our context\n", "# so we can omit cutlass.cuda.initialize_cuda_context()\n", "stream = current_stream()\n", @@ -585,7 +588,7 @@ "\n", "# Calculate the time spent when launching kernels in a stream\n", "# Results are in ms\n", - "stream_time = start.elapsed_time(end) \n", + "stream_time = start.elapsed_time(end)\n", "\n", "# Warmup our GPU again\n", "g.replay()\n", diff --git a/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb b/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb index 3a3f9ed7..93ac16b6 100644 --- a/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb +++ b/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb @@ -90,7 +90,9 @@ " \"\"\"\n", " Demonstrates coalesce operation flattening and combining modes\n", " \"\"\"\n", - " layout = cute.make_layout((2, (1, 6)), stride=(1, (cutlass.Int32(6), 2))) # Dynamic stride\n", + " layout = cute.make_layout(\n", + " (2, (1, 6)), stride=(1, (cutlass.Int32(6), 2))\n", + " ) # Dynamic stride\n", " result = cute.coalesce(layout)\n", "\n", " print(\">>> Original:\", layout)\n", @@ -98,6 +100,7 @@ " print(\">>> Coalesced:\", result)\n", " cute.printf(\">?? Coalesced: {}\", result)\n", "\n", + "\n", "coalesce_example()" ] }, @@ -275,8 +278,7 @@ " 3. for all i, 0 <= i < size(@a layout), @a result(i) == @a layout(i)\n", " \"\"\"\n", " layout = cute.make_layout(\n", - " ((2, (3, 4)), (3, 2), 1),\n", - " stride=((4, (8, 24)), (2, 6), 12)\n", + " ((2, (3, 4)), (3, 2), 1), stride=((4, (8, 24)), (2, 6), 12)\n", " )\n", " result = cute.coalesce(layout)\n", "\n", @@ -288,21 +290,26 @@ " original_size = cute.size(layout)\n", " coalesced_size = cute.size(result)\n", " print(f\"Original size: {original_size}, Coalesced size: {coalesced_size}\")\n", - " assert coalesced_size == original_size, \\\n", - " f\"Size mismatch: original {original_size}, coalesced {coalesced_size}\"\n", - " \n", + " assert coalesced_size == original_size, (\n", + " f\"Size mismatch: original {original_size}, coalesced {coalesced_size}\"\n", + " )\n", + "\n", " print(\">>> 2. Checking depth of coalesced layout <= 1:\")\n", " depth = cute.depth(result)\n", " print(f\"Depth of coalesced layout: {depth}\")\n", " assert depth <= 1, f\"Depth of coalesced layout should be <= 1, got {depth}\"\n", "\n", - " print(\">>> 3. Checking layout functionality remains the same after the coalesce operation:\")\n", + " print(\n", + " \">>> 3. Checking layout functionality remains the same after the coalesce operation:\"\n", + " )\n", " for i in cutlass.range_constexpr(original_size):\n", " original_value = layout(i)\n", " coalesced_value = result(i)\n", " print(f\"Index {i}: original {original_value}, coalesced {coalesced_value}\")\n", - " assert coalesced_value == original_value, \\\n", + " assert coalesced_value == original_value, (\n", " f\"Value mismatch at index {i}: original {original_value}, coalesced {coalesced_value}\"\n", + " )\n", + "\n", "\n", "coalesce_post_conditions()" ] @@ -338,11 +345,12 @@ "\n", " # Coalesce with mode-wise profile (1,1) = coalesce both modes\n", " result = cute.coalesce(layout, target_profile=(1, 1))\n", - " \n", + "\n", " # Print results\n", " print(\">>> Original: \", layout)\n", " print(\">>> Coalesced Result: \", result)\n", "\n", + "\n", "bymode_coalesce_example()" ] }, @@ -387,18 +395,19 @@ " \"\"\"\n", " Demonstrates basic layout composition R = A ◦ B\n", " \"\"\"\n", - " A = cute.make_layout((6, 2), stride=(cutlass.Int32(8), 2)) # Dynamic stride\n", + " A = cute.make_layout((6, 2), stride=(cutlass.Int32(8), 2)) # Dynamic stride\n", " B = cute.make_layout((4, 3), stride=(3, 1))\n", " R = cute.composition(A, B)\n", "\n", " # Print static and dynamic information\n", " print(\">>> Layout A:\", A)\n", " cute.printf(\">?? Layout A: {}\", A)\n", - " print(\">>> Layout B:\", B) \n", + " print(\">>> Layout B:\", B)\n", " cute.printf(\">?? Layout B: {}\", B)\n", " print(\">>> Composition R = A ◦ B:\", R)\n", " cute.printf(\">?? Composition R: {}\", R)\n", "\n", + "\n", "composition_example()" ] }, @@ -438,14 +447,8 @@ " Shows difference between static and dynamic composition results\n", " \"\"\"\n", " # Static version - using compile-time values\n", - " A_static = cute.make_layout(\n", - " (10, 2), \n", - " stride=(16, 4)\n", - " )\n", - " B_static = cute.make_layout(\n", - " (5, 4), \n", - " stride=(1, 5)\n", - " )\n", + " A_static = cute.make_layout((10, 2), stride=(16, 4))\n", + " B_static = cute.make_layout((5, 4), stride=(1, 5))\n", " R_static = cute.composition(A_static, B_static)\n", "\n", " # Static print shows compile-time info\n", @@ -457,20 +460,21 @@ " # Dynamic version - using runtime Int32 values\n", " A_dynamic = cute.make_layout(\n", " (cutlass.Int32(10), cutlass.Int32(2)),\n", - " stride=(cutlass.Int32(16), cutlass.Int32(4))\n", + " stride=(cutlass.Int32(16), cutlass.Int32(4)),\n", " )\n", " B_dynamic = cute.make_layout(\n", " (cutlass.Int32(5), cutlass.Int32(4)),\n", - " stride=(cutlass.Int32(1), cutlass.Int32(5))\n", + " stride=(cutlass.Int32(1), cutlass.Int32(5)),\n", " )\n", " R_dynamic = cute.composition(A_dynamic, B_dynamic)\n", - " \n", + "\n", " # Dynamic printf shows runtime values\n", " cute.printf(\">?? Dynamic composition:\")\n", " cute.printf(\">?? A_dynamic: {}\", A_dynamic)\n", " cute.printf(\">?? B_dynamic: {}\", B_dynamic)\n", " cute.printf(\">?? R_dynamic: {}\", R_dynamic)\n", "\n", + "\n", "composition_static_vs_dynamic_layout()" ] }, @@ -511,12 +515,12 @@ " \"\"\"\n", " # Define the original layout A\n", " A = cute.make_layout(\n", - " (cutlass.Int32(12), (cutlass.Int32(4), cutlass.Int32(8))), \n", - " stride=(cutlass.Int32(59), (cutlass.Int32(13), cutlass.Int32(1)))\n", + " (cutlass.Int32(12), (cutlass.Int32(4), cutlass.Int32(8))),\n", + " stride=(cutlass.Int32(59), (cutlass.Int32(13), cutlass.Int32(1))),\n", " )\n", "\n", " # Define the tiler for by-mode composition\n", - " tiler = (3, 8) # Apply 3:1 to mode-0 and 8:1 to mode-1\n", + " tiler = (3, 8) # Apply 3:1 to mode-0 and 8:1 to mode-1\n", "\n", " # Apply by-mode composition\n", " result = cute.composition(A, tiler)\n", @@ -529,6 +533,7 @@ " print(\">>> By-mode Composition Result:\", result)\n", " cute.printf(\">?? By-mode Composition Result: {}\", result)\n", "\n", + "\n", "bymode_composition_example()" ] }, @@ -571,19 +576,20 @@ " \"\"\"\n", " # Define the original layout\n", " layout = cute.make_layout((4, 2, 3), stride=(2, 1, 8)) # (4,2,3):(2,1,8)\n", - " \n", + "\n", " # Define the tiler\n", " tiler = cute.make_layout(4, stride=2) # Apply to layout 4:2\n", - " \n", + "\n", " # Apply logical divide\n", " result = cute.logical_divide(layout, tiler=tiler)\n", - " \n", + "\n", " # Print results\n", " print(\">>> Layout:\", layout)\n", " print(\">>> Tiler :\", tiler)\n", " print(\">>> Logical Divide Result:\", result)\n", " cute.printf(\">?? Logical Divide Result: {}\", result)\n", "\n", + "\n", "logical_divide_1d_example()" ] }, @@ -620,21 +626,26 @@ " Result Shape : ((TileM,RestM), (TileN,RestN), L, ...)\n", " \"\"\"\n", " # Define the original layout\n", - " layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1))) # (9,(4,8)):(59,(13,1))\n", - " \n", + " layout = cute.make_layout(\n", + " (9, (4, 8)), stride=(59, (13, 1))\n", + " ) # (9,(4,8)):(59,(13,1))\n", + "\n", " # Define the tiler\n", - " tiler = (cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", - " cute.make_layout((2, 4), stride=(1, 8))) # Apply to mode-1 layout (2,4):(1,8)\n", - " \n", + " tiler = (\n", + " cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", + " cute.make_layout((2, 4), stride=(1, 8)),\n", + " ) # Apply to mode-1 layout (2,4):(1,8)\n", + "\n", " # Apply logical divide\n", " result = cute.logical_divide(layout, tiler=tiler)\n", - " \n", + "\n", " # Print results\n", " print(\">>> Layout:\", layout)\n", " print(\">>> Tiler :\", tiler)\n", " print(\">>> Logical Divide Result:\", result)\n", " cute.printf(\">?? Logical Divide Result: {}\", result)\n", "\n", + "\n", "logical_divide_2d_example()" ] }, @@ -673,21 +684,26 @@ " Result Shape : ((TileM,TileN), (RestM,RestN,L,...))\n", " \"\"\"\n", " # Define the original layout\n", - " layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1))) # (9,(4,8)):(59,(13,1))\n", - " \n", + " layout = cute.make_layout(\n", + " (9, (4, 8)), stride=(59, (13, 1))\n", + " ) # (9,(4,8)):(59,(13,1))\n", + "\n", " # Define the tiler\n", - " tiler = (cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", - " cute.make_layout((2, 4), stride=(1, 8))) # Apply to mode-1 layout (2,4):(1,8)\n", - " \n", + " tiler = (\n", + " cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", + " cute.make_layout((2, 4), stride=(1, 8)),\n", + " ) # Apply to mode-1 layout (2,4):(1,8)\n", + "\n", " # Apply zipped divide\n", " result = cute.zipped_divide(layout, tiler=tiler)\n", - " \n", + "\n", " # Print results\n", " print(\">>> Layout:\", layout)\n", " print(\">>> Tiler :\", tiler)\n", " print(\">>> Zipped Divide Result:\", result)\n", " cute.printf(\">?? Zipped Divide Result: {}\", result)\n", "\n", + "\n", "zipped_divide_example()" ] }, @@ -724,21 +740,26 @@ " Result Shape : ((TileM,TileN), RestM, RestN, L, ...)\n", " \"\"\"\n", " # Define the original layout\n", - " layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1))) # (9,(4,8)):(59,(13,1))\n", - " \n", + " layout = cute.make_layout(\n", + " (9, (4, 8)), stride=(59, (13, 1))\n", + " ) # (9,(4,8)):(59,(13,1))\n", + "\n", " # Define the tiler\n", - " tiler = (cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", - " cute.make_layout((2, 4), stride=(1, 8))) # Apply to mode-1 layout (2,4):(1,8)\n", - " \n", + " tiler = (\n", + " cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", + " cute.make_layout((2, 4), stride=(1, 8)),\n", + " ) # Apply to mode-1 layout (2,4):(1,8)\n", + "\n", " # Apply tiled divide\n", " result = cute.tiled_divide(layout, tiler=tiler)\n", - " \n", + "\n", " # Print results\n", " print(\">>> Layout:\", layout)\n", " print(\">>> Tiler :\", tiler)\n", " print(\">>> Tiled Divide Result:\", result)\n", " cute.printf(\">?? Tiled Divide Result: {}\", result)\n", "\n", + "\n", "tiled_divide_example()" ] }, @@ -775,21 +796,26 @@ " Result Shape : (TileM, TileN, RestM, RestN, L, ...)\n", " \"\"\"\n", " # Define the original layout\n", - " layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1))) # (9,(4,8)):(59,(13,1))\n", - " \n", + " layout = cute.make_layout(\n", + " (9, (4, 8)), stride=(59, (13, 1))\n", + " ) # (9,(4,8)):(59,(13,1))\n", + "\n", " # Define the tiler\n", - " tiler = (cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", - " cute.make_layout((2, 4), stride=(1, 8))) # Apply to mode-1 layout (2,4):(1,8)\n", - " \n", + " tiler = (\n", + " cute.make_layout(3, stride=3), # Apply to mode-0 layout 3:3\n", + " cute.make_layout((2, 4), stride=(1, 8)),\n", + " ) # Apply to mode-1 layout (2,4):(1,8)\n", + "\n", " # Apply flat divide\n", " result = cute.flat_divide(layout, tiler=tiler)\n", - " \n", + "\n", " # Print results\n", " print(\">>> Layout:\", layout)\n", " print(\">>> Tiler :\", tiler)\n", " print(\">>> Flat Divide Result:\", result)\n", " cute.printf(\">?? Flat Divide Result: {}\", result)\n", "\n", + "\n", "flat_divide_example()" ] }, @@ -834,19 +860,20 @@ " \"\"\"\n", " # Define the original layout\n", " layout = cute.make_layout((2, 2), stride=(4, 1)) # (2,2):(4,1)\n", - " \n", + "\n", " # Define the tiler\n", " tiler = cute.make_layout(6, stride=1) # Apply to layout 6:1\n", - " \n", + "\n", " # Apply logical product\n", " result = cute.logical_product(layout, tiler=tiler)\n", - " \n", + "\n", " # Print results\n", " print(\">>> Layout:\", layout)\n", " print(\">>> Tiler :\", tiler)\n", " print(\">>> Logical Product Result:\", result)\n", " cute.printf(\">?? Logical Product Result: {}\", result)\n", "\n", + "\n", "logical_product_1d_example()" ] }, @@ -886,16 +913,16 @@ " \"\"\"\n", " # Define the original layout\n", " layout = cute.make_layout((2, 5), stride=(5, 1))\n", - " \n", + "\n", " # Define the tiler\n", " tiler = cute.make_layout((3, 4), stride=(1, 3))\n", - " \n", + "\n", " # Apply blocked product\n", " blocked_result = cute.blocked_product(layout, tiler=tiler)\n", "\n", " # Apply raked product\n", " raked_result = cute.raked_product(layout, tiler=tiler)\n", - " \n", + "\n", " # Print results\n", " print(\">>> Layout:\", layout)\n", " print(\">>> Tiler :\", tiler)\n", @@ -904,6 +931,7 @@ " cute.printf(\">?? Blocked Product Result: {}\", blocked_result)\n", " cute.printf(\">?? Raked Product Result: {}\", raked_result)\n", "\n", + "\n", "blocked_raked_product_example()" ] }, @@ -950,16 +978,16 @@ " \"\"\"\n", " # Define the original layout\n", " layout = cute.make_layout((2, 5), stride=(5, 1))\n", - " \n", + "\n", " # Define the tiler\n", " tiler = cute.make_layout((3, 4), stride=(1, 3))\n", "\n", " # Apply zipped product\n", " zipped_result = cute.zipped_product(layout, tiler=tiler)\n", - " \n", + "\n", " # Apply tiled product\n", " tiled_result = cute.tiled_product(layout, tiler=tiler)\n", - " \n", + "\n", " # Apply flat product\n", " flat_result = cute.flat_product(layout, tiler=tiler)\n", "\n", @@ -973,6 +1001,7 @@ " cute.printf(\">?? Tiled Product Result: {}\", tiled_result)\n", " cute.printf(\">?? Flat Product Result: {}\", flat_result)\n", "\n", + "\n", "zipped_tiled_flat_product_example()" ] } diff --git a/examples/python/CuTeDSL/notebooks/data_types.ipynb b/examples/python/CuTeDSL/notebooks/data_types.ipynb index dc305fff..dd05177a 100644 --- a/examples/python/CuTeDSL/notebooks/data_types.ipynb +++ b/examples/python/CuTeDSL/notebooks/data_types.ipynb @@ -6,8 +6,6 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import List\n", - "\n", "import cutlass\n", "import cutlass.cute as cute" ] @@ -83,12 +81,13 @@ "@cute.jit\n", "def bar():\n", " a = cutlass.Float32(3.14)\n", - " print(\"a(static) =\", a) # prints `a(static) = ?`\n", - " cute.printf(\"a(dynamic) = {}\", a) # prints `a(dynamic) = 3.140000`\n", + " print(\"a(static) =\", a) # prints `a(static) = ?`\n", + " cute.printf(\"a(dynamic) = {}\", a) # prints `a(dynamic) = 3.140000`\n", "\n", " b = cutlass.Int32(5)\n", - " print(\"b(static) =\", b) # prints `b(static) = 5`\n", - " cute.printf(\"b(dynamic) = {}\", b) # prints `b(dynamic) = 5`\n", + " print(\"b(static) =\", b) # prints `b(static) = 5`\n", + " cute.printf(\"b(dynamic) = {}\", b) # prints `b(dynamic) = 5`\n", + "\n", "\n", "bar()" ] @@ -154,6 +153,7 @@ " f = e.to(cutlass.Int8)\n", " cute.printf(\"Int32({}) => Int8({}) (truncated due to range limitation)\", e, f)\n", "\n", + "\n", "type_conversion()" ] }, @@ -241,7 +241,8 @@ " not_a = ~a\n", " cute.printf(\"~a = {}\", not_a)\n", "\n", - "operator_demo()\n" + "\n", + "operator_demo()" ] } ], diff --git a/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb b/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb index b4dc8616..5ea1bf1c 100644 --- a/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb +++ b/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "editable": true, "slideshow": { @@ -14,6 +14,7 @@ "source": [ "import torch\n", "from functools import partial\n", + "from typing import List\n", "\n", "import cutlass\n", "import cutlass.cute as cute\n", @@ -24,72 +25,97 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Tutorial: Elementwise Add Kernel in CuTe DSL\n", + "# Kernel Tutorial: Building an Efficient Elementwise Add Kernel with CuTe DSL\n", "\n", - "This tutorial demonstrates how to implement a simple elementwise\n", - "addition kernel using the CuTe DSL (Domain Specific Language).\n", + "This tutorial demonstrates how to implement and optimize a GPU elementwise addition kernel using the CuTe DSL. \n", "\n", + "## Learning Objectives\n", "\n", + "In this tutorial, you will learn building an efficient elementwise kernel in CuTe DSL step by step:\n", + "- How to implement basic GPU kernels using CuTe DSL in basic CUDA techniques\n", + "- How to benchmark performance of the kernel\n", + "- How to tile and partition tensor and map to basic CuTe Layout\n", + "- What it Thread & Value Layout and mapping from thread & value index to logical coordinate\n", + "- How to implement advanced kernel with TV layout and tune performance to achieve peak performance\n", "\n", - "Elementwise Addition\n", - "---------------------\n", + "## Understanding Elementwise Addition\n", "\n", - "Elementwise addition is a fundamental operation in linear algebra.\n", - "Given two tensors of the same shape, the operation performs element-wise\n", - "addition to produce a result tensor of the same shape.\n", + "Elementwise addition is a fundamental operation in linear algebra and deep learning. Given two tensors of the same shape, the operation performs element-wise addition to produce a result tensor of the same shape.\n", "\n", - "For two 2D tensors :math:`A` and :math:`B` of shape :math:`(M, N)`,\n", - "the elementwise addition operation :math:`C = A + B` is defined as:\n", + "For two 2D tensors $A$ and $B$ of shape $(M, N)$, the elementwise addition operation $C = A + B$ is defined as:\n", "\n", "$\n", " C_{i,j} = A_{i,j} + B_{i,j}\n", "$\n", "\n", "where:\n", - "\n", "- $i \\in [0, M-1]$ represents the row index\n", "- $j \\in [0, N-1]$ represents the column index\n", - "- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ \n", - " in tensors $A$, $B$, and $C$ respectively\n", + "- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ in tensors $A$, $B$, and $C$ respectively\n", "\n", - "This operation is performed independently for each element position,\n", - "making it highly parallelizable and well-suited for GPU implementation.\n", + "This operation has several important characteristics:\n", + "1. **Parallelizable**: Each element can be computed independently\n", + "2. **Memory-bound**: Performance limited by memory bandwidth rather than compute\n", + "3. **Coalescing-sensitive**: Efficiency depends on memory access patterns\n", + "4. **Vectorization-friendly**: Multiple elements can be processed together\n", "\n", - "Naive Elementwise Add Kernel\n", - "-----------------------------\n", + "## Naive Elementwise Add Kernel\n", "\n", - "Let's start with a naive implementation that loads each element from\n", - "$A$ and $B$, adds them, and stores the result back to $C$." + "Let's start with a naive implementation to establish a baseline before exploring optimizations." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# Basic Kernel Implementation\n", + "# ---------------------\n", + "# This is our first implementation of the elementwise add kernel.\n", + "# It follows a simple 1:1 mapping between threads and tensor elements.\n", + "\n", + "\n", "@cute.kernel\n", "def naive_elementwise_add_kernel(\n", - " gA: cute.Tensor,\n", - " gB: cute.Tensor,\n", - " gC: cute.Tensor,\n", + " gA: cute.Tensor, # Input tensor A\n", + " gB: cute.Tensor, # Input tensor B\n", + " gC: cute.Tensor, # Output tensor C = A + B\n", "):\n", - " tidx, _, _ = cute.arch.thread_idx()\n", - " bidx, _, _ = cute.arch.block_idx()\n", - " bdim, _, _ = cute.arch.block_dim()\n", + " # Step 1: Get thread indices\n", + " # ------------------------\n", + " # CUDA threads are organized in a 3D grid of thread blocks\n", + " # Here we only use the x-dimension for simplicity\n", + " tidx, _, _ = cute.arch.thread_idx() # Thread index within block (0 to bdim-1)\n", + " bidx, _, _ = cute.arch.block_idx() # Block index in grid (0 to grid_dim-1)\n", + " bdim, _, _ = cute.arch.block_dim() # Number of threads per block\n", "\n", - " thread_idx = bidx * bdim + tidx\n", + " # Calculate global thread index\n", + " # This gives each thread a unique ID across all blocks\n", + " thread_idx = bidx * bdim + tidx # Global thread ID\n", "\n", - " # Map thread index to logical index of input tensor\n", - " m, n = gA.shape\n", - " ni = thread_idx % n\n", - " mi = thread_idx // n\n", + " # Step 2: Map thread index to tensor coordinates\n", + " # -------------------------------------------\n", + " # Each thread will process one element of the input tensors\n", + " m, n = gA.shape # Get tensor dimensions (M rows × N columns)\n", "\n", - " # Map logical index to physical address via tensor layout\n", - " a_val = gA[mi, ni]\n", - " b_val = gB[mi, ni]\n", + " # Convert linear thread index to 2D coordinates:\n", + " # - ni: column index (0 to n-1)\n", + " # - mi: row index (0 to m-1)\n", + " ni = thread_idx % n # Column index (faster varying dimension)\n", + " mi = thread_idx // n # Row index (slower varying dimension)\n", "\n", - " # Perform element-wise addition\n", + " # Step 3: Load and process data\n", + " # ---------------------------\n", + " # Load values from input tensors\n", + " # The tensor layout automatically handles the conversion from\n", + " # logical indices (mi, ni) to physical memory addresses\n", + " a_val = gA[mi, ni] # Load element from tensor A\n", + " b_val = gB[mi, ni] # Load element from tensor B\n", + "\n", + " # Step 4: Store result\n", + " # ------------------\n", + " # Write the sum back to the output tensor\n", " gC[mi, ni] = a_val + b_val" ] }, @@ -99,133 +125,267 @@ "source": [ "### Structure of the Kernel\n", "\n", - "The naive kernel simply maps each thread to one element with a 1-to-1 mapping.\n", - "In this kernel, we don't use CuTe layout algebra but only use basic\n", - "addressing to index the tensor.\n", + "The naive kernel implementation follows a straightforward but effective structure for parallel processing on the GPU. Here's a detailed breakdown of how it works:\n", "\n", - "We can launch the kernel with the following JIT function:" + "1. **Thread Organization and Indexing**\n", + " - Each CUDA thread is uniquely identified using a combination of:\n", + " * `thread_idx` (tidx): Thread index within a block (0 to bdim-1)\n", + " * `block_idx` (bidx): Block index in the grid\n", + " * `block_dim` (bdim): Number of threads per block\n", + " - Global thread ID is calculated as: `thread_idx = bidx * bdim + tidx`\n", + "\n", + "2. **Coordinate Mapping**\n", + " - The kernel maps each thread's global ID to 2D tensor coordinates:\n", + " * `ni = thread_idx % n` (column index - faster varying)\n", + " * `mi = thread_idx // n` (row index - slower varying)\n", + " - This mapping ensures coalesced memory access by having adjacent threads access adjacent memory locations\n", + "\n", + "3. **Memory Access Pattern**\n", + " - Each thread:\n", + " * Loads one element from tensor A: `a_val = gA[mi, ni]`\n", + " * Loads one element from tensor B: `b_val = gB[mi, ni]`\n", + " * Performs addition: `a_val + b_val`\n", + " * Stores result to tensor C: `gC[mi, ni] = result`\n", + " - Memory Considerations\n", + " * Uses 1:1 thread-to-element mapping\n", + " * Memory accesses are coalesced when threads in a warp access consecutive elements\n", + " * No explicit use of shared memory or register blocking\n", + " * Limited ability to hide memory latency due to single element processing\n", + "\n", + "This naive implementation provides a baseline for understanding more optimized versions that follow, which introduce:\n", + "- Vectorized memory access\n", + "- Thread and value (TV) layouts\n", + "- Advanced tiling strategies\n", + "- Custom binary operations\n", + "\n", + "For more details about coalesced memory access, please read: https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/#coalesced-access-to-global-memory\n", + "\n", + "\n", + "### Kernel Launch Configuration and Testing\n", + "\n", + "This section demonstrates how to:\n", + "1. Configure and launch the kernel with `cute.jit` function\n", + "2. Set up test data with `torch`\n", + "3. Verify correctness\n", + "\n", + "**Launch Configuration**\n", + " - Uses 256 threads per block (common choice for good occupancy)\n", + " - Grid size calculated based on total elements: `(m * n) // threads_per_block`\n", + " - Single dimension block and grid configuration for simplicity\n", + "\n", + "#### Host JIT function to launch kernel" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "@cute.jit\n", + "@cute.jit # Just-in-time compilation decorator\n", "def naive_elementwise_add(\n", - " mA: cute.Tensor,\n", - " mB: cute.Tensor,\n", - " mC: cute.Tensor\n", + " mA: cute.Tensor, # Input tensor A\n", + " mB: cute.Tensor, # Input tensor B\n", + " mC: cute.Tensor, # Output tensor C\n", "):\n", + " # Configure kernel launch parameters\n", + " # --------------------------------\n", + " # Choose number of threads per block\n", + " # 256 is a common choice as it:\n", + " # - Allows good occupancy on most GPUs\n", + " # - Is a multiple of 32 (warp size)\n", + " # - Provides enough threads for latency hiding\n", " num_threads_per_block = 256\n", "\n", - " m, n = mA.shape\n", + " # Get input dimensions\n", + " m, n = mA.shape # Matrix dimensions (M rows × N columns)\n", + "\n", + " # Create kernel instance\n", " kernel = naive_elementwise_add_kernel(mA, mB, mC)\n", - " kernel.launch(grid=((m * n) // num_threads_per_block, 1, 1),\n", - " block=(num_threads_per_block, 1, 1))\n", "\n", - "M, N = 2048, 2048\n", - "\n", - "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", - "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", - "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", - "\n", - "a_ = from_dlpack(a, assumed_align=16)\n", - "b_ = from_dlpack(b, assumed_align=16)\n", - "c_ = from_dlpack(c, assumed_align=16)\n", - "\n", - "# Compile kernel\n", - "naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)\n", - "naive_elementwise_add_(a_, b_, c_)\n", - "\n", - "# verify correctness\n", - "torch.testing.assert_close(c, a + b)" + " # Launch kernel with calculated grid dimensions\n", + " # -------------------------------------------\n", + " # Grid size calculation:\n", + " # - Total elements: m * n\n", + " # - Blocks needed: ceil(total_elements / threads_per_block)\n", + " # - Using integer division here assumes m * n is multiple of threads_per_block\n", + " kernel.launch(\n", + " grid=((m * n) // num_threads_per_block, 1, 1), # Number of blocks in x,y,z\n", + " block=(num_threads_per_block, 1, 1), # Threads per block in x,y,z\n", + " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Benchmark performance\n", - "\n", - "Here's a utility function to benchmark our kernel implementations:" + "#### Setup test data with torch" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def benchmark(callable, *, num_warmups, num_iterations):\n", - " start_event = torch.cuda.Event(enable_timing=True)\n", - " end_event = torch.cuda.Event(enable_timing=True)\n", + "# Test Setup\n", + "# ----------\n", + "# Define test dimensions\n", + "M, N = 16384, 8192 # Using large matrices to measure performance\n", "\n", - " torch.cuda.synchronize()\n", + "# Create test data on GPU\n", + "# ----------------------\n", + "# Using float16 (half precision) for:\n", + "# - Reduced memory bandwidth requirements\n", + "# - Better performance on modern GPUs\n", + "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16) # Random input A\n", + "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16) # Random input B\n", + "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16) # Output buffer\n", "\n", - " for _ in range(num_warmups):\n", - " callable()\n", + "# Calculate total elements for bandwidth calculations\n", + "num_elements = sum([a.numel(), b.numel(), c.numel()])\n", "\n", - " start_event.record(stream=torch.cuda.current_stream())\n", - " for _ in range(num_iterations):\n", - " callable()\n", - " end_event.record(stream=torch.cuda.current_stream())\n", - " torch.cuda.synchronize()\n", - "\n", - " elapsed_time = start_event.elapsed_time(end_event)\n", - " avg_time = elapsed_time / num_iterations\n", - "\n", - " print(f\"Average execution time: {avg_time:.4f} ms\")\n", - " print(f\"Throughput: {(3 * a.numel() * 2) / (avg_time / 1000) / 1e9:.2f} GB/s\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Average execution time: 0.0385 ms\n", - "Throughput: 653.44 GB/s\n" - ] - } - ], - "source": [ - "benchmark(partial(naive_elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=100)" + "# Convert PyTorch tensors to CuTe tensors\n", + "# -------------------------------------\n", + "# from_dlpack creates CuTe tensor views of PyTorch tensors\n", + "# assumed_align=16 ensures proper memory alignment for vectorized access\n", + "a_ = from_dlpack(a, assumed_align=16) # CuTe tensor A\n", + "b_ = from_dlpack(b, assumed_align=16) # CuTe tensor B\n", + "c_ = from_dlpack(c, assumed_align=16) # CuTe tensor C" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Performance Analysis\n", + "#### Compile and run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compile the kernel for the specific input types\n", + "naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)\n", "\n", - "While our naive implementation maps thread indices to contiguous tensor\n", - "dimensions for coalesced memory access, it doesn't have enough\n", - "in-flight load & store operations to hide memory latency.\n", + "# Run the kernel\n", + "naive_elementwise_add_(a_, b_, c_)\n", "\n", - "According to Little's Law:\n", + "# Verify Results\n", + "# -------------\n", + "# Compare our kernel output with PyTorch's native implementation\n", + "torch.testing.assert_close(c, a + b) # Raises error if results don't match" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance Analysis and Benchmarking\n", + "\n", + "To understand and improve our kernel's performance, we need to measure its execution time and memory throughput. Let's analyze several key metrics:\n", + "\n", + "* **Execution Time**\n", + " - Measures raw kernel performance in microseconds\n", + " - Lower is better\n", + " - Affected by GPU clock speed, memory bandwidth, and kernel efficiency\n", + "* **Memory Throughput**\n", + " - Measures how fast we can copy data (GB/s)\n", + " - Higher is better\n", + " - Theoretical peak varies by GPU model\n", + " - For elementwise add:\n", + " * Read: 2 elements (A and B)\n", + " * Write: 1 element (C)\n", + " * Total bytes = (2 reads + 1 write) × elements × sizeof(dtype)\n", + "\n", + "Below is our benchmarking utility that measures these metrics:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark(callable, a_, b_, c_):\n", + " avg_time_us = cute.testing.benchmark(\n", + " callable,\n", + " kernel_arguments=cute.testing.JitArguments(a_, b_, c_),\n", + " warmup_iterations=5,\n", + " iterations=100,\n", + " )\n", + "\n", + " # Calculate metrics\n", + " # ----------------\n", + " dtype = a_.element_type\n", + "\n", + " # Calculate total bytes transferred:\n", + " # - 2 reads (A and B) + 1 write (C)\n", + " # - Each element is dtype.width bits\n", + " bytes_per_element = dtype.width // 8\n", + " total_bytes = num_elements * bytes_per_element\n", + "\n", + " # Calculate achieved bandwidth\n", + " achieved_bandwidth = total_bytes / (avg_time_us * 1000) # GB/s\n", + "\n", + " # Print results\n", + " # ------------\n", + " print(f\"Performance Metrics:\")\n", + " print(f\"-------------------\")\n", + " print(f\"Kernel execution time: {avg_time_us:.4f} us\")\n", + " print(f\"Memory throughput: {achieved_bandwidth:.2f} GB/s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "benchmark(naive_elementwise_add_, a_, b_, c_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Theoretical Analysis\n", + "\n", + "This section analyze the performance characteristics and optimization opportunities of our elementwise addition kernel through several theoretical frameworks.\n", + "\n", + "#### Little's Law\n", + "\n", + "Little's Law provides crucial insights into relationship between latency, bandwidth and inflight operations:\n", "\n", "$ L = \\lambda \\times W $\n", "\n", "Where:\n", - "- $L$ is the average number of items in a system\n", - "- $\\lambda$ is the average arrival rate of items (bandwidth)\n", - "- $W$ is the average time an item spends in the system (latency)\n", + "- $L$: Number of in-flight memory operations needed\n", + "- $\\lambda$: Target memory bandwidth (bytes/cycle)\n", + "- $W$: Memory system latency (cycles)\n", "\n", - "For our elementwise addition kernel:\n", + "According to *Little's Law*, naive implementation has\n", + " - 1 element (4 bytes load + 2 bytes store) per thread\n", + " - 256 threads/block × N blocks\n", + " - Limited in-flight operations\n", "\n", - "1. $L$: The number of load & store operations in-flight\n", - "2. $\\lambda$ (Bandwidth): Data transfer rate between memory and compute units\n", - "3. $W$ (Latency): Round-trip delay of memory requests\n", + "In some GPUs, it's insufficient parallelism to saturate memory bandwidth.\n", "\n", - "For memory-bound operations like elementwise addition, performance is\n", - "limited by the number of in-flight load & store operations.\n", + "#### Optimization Strategies\n", "\n", + "Based on this analysis, one commonly used technique is **Vectorization**. Instead of 1 element \n", + "per load per thread, vectorization allows multiple element per load\n", + " - Reduces instruction count\n", + " - Improves memory coalescing\n", + " - Increases operations in flight" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "## Vectorized Load and Store\n", "\n", "To improve performance according to Little's Law, we need to increase the number\n", @@ -257,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -273,8 +433,8 @@ "\n", " thread_idx = bidx * bdim + tidx\n", "\n", - " # Map thread index to logical index of input tensor\n", - " m, n = gA.shape[1] # thread-domain\n", + " # Map thread index to logical index of input tensor in unit of vector\n", + " m, n = gA.shape[1] # thread-domain\n", " ni = thread_idx % n\n", " mi = thread_idx // n\n", "\n", @@ -296,55 +456,58 @@ "with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n", "we can extract a `(1,4)` sub-tensor from `gA`, `gB` and `gC` like \n", "\n", + "$ gA[(None, (mi, ni))]: $\n", + "\n", + "$\n", + " \\begin{array}{ccccc}\n", + " Layout: & ( & (1,4) & , & (2048,512) & ) & : & ((0,1),(2048,4)) & \\xrightarrow{\\text{slice}} & ((1,4)):((0,1)) \\\\\n", + " & & \\underbrace{\\phantom{(1,4)}} & & \\underbrace{\\phantom{(2048,512)}} & & \\\\\n", + " Coord: & ( & None & , & (mi, ni) & ) & &\n", + " \\end{array}\n", + "$\n", + "\n", + "Then tensor data can be loaded into vector via the `gA[(None, (mi, ni))].load()` method. It is equivalent to\n", + "\n", "```python\n", - "gA[(None, (mi, ni))]\n", - "\n", + "v0 = gA[(0, (mi, ni))] # => mA[(mi, ni * 4 + 0)]\n", + "v1 = gA[(1, (mi, ni))] # => mA[(mi, ni * 4 + 1)]\n", + "v2 = gA[(2, (mi, ni))] # => mA[(mi, ni * 4 + 2)]\n", + "v3 = gA[(3, (mi, ni))] # => mA[(mi, ni * 4 + 3)]\n", "```\n", "\n", - "Then tensor data can be loaded into vector via the `.load()` method.\n", + "### Assumed Alignment\n", "\n", + "In order to guide compile to use vectorized load/store, we must tell compiler to assume alignment of incoming pointer. \n", + "It's on users side to guarantee actual pointer at runtime meet the alignment restriction.\n", "\n", + "```python\n", + "a_ = from_dlpack(a, assumed_align=16)\n", + "b_ = from_dlpack(b, assumed_align=16)\n", + "c_ = from_dlpack(c, assumed_align=16)\n", + "\n", + "# Compile kernel with alignment assumption\n", + "compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n", "```\n", - " slice\n", - " ((1,4),(2048,512)):((0,1),(2048,4)) ==> ((1,4)):((0,1))\n", - " ^ ^ ^\n", - " | | |\n", - " (None, (mi, ni))\n", - "```" + "\n", + "It's worth to note that partitioned or tiled tensor could have different alignment of its base pointer because of offset\n", + "during sub-slice." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[DSL INFO] Tiled Tensors:\n", - "[DSL INFO] gA = tensor> o ((1,4),(2048,512)):((0,1),(2048,4))>\n", - "[DSL INFO] gB = tensor> o ((1,4),(2048,512)):((0,1),(2048,4))>\n", - "[DSL INFO] gC = tensor> o ((1,4),(2048,512)):((0,1),(2048,4))>\n", - "[DSL INFO] sliced gA = tensor> o ((1,4)):((0,1))>\n", - "[DSL INFO] sliced gB = tensor> o ((1,4)):((0,1))>\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", - "def vectorized_elementwise_add(\n", - " mA: cute.Tensor,\n", - " mB: cute.Tensor,\n", - " mC: cute.Tensor\n", - "):\n", + "def vectorized_elementwise_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):\n", " threads_per_block = 256\n", "\n", " gA = cute.zipped_divide(mA, (1, 4))\n", " gB = cute.zipped_divide(mB, (1, 4))\n", " gC = cute.zipped_divide(mC, (1, 4))\n", "\n", - " print(f\"[DSL INFO] Tiled Tensors:\")\n", + " print(\"[DSL INFO] Tiled Tensors:\")\n", " print(f\"[DSL INFO] gA = {gA}\")\n", " print(f\"[DSL INFO] gB = {gB}\")\n", " print(f\"[DSL INFO] gC = {gC}\")\n", @@ -354,6 +517,7 @@ " block=(threads_per_block, 1, 1),\n", " )\n", "\n", + "\n", "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", @@ -371,20 +535,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Average execution time: 0.0202 ms\n", - "Throughput: 1244.98 GB/s\n" - ] - } - ], + "outputs": [], "source": [ - "benchmark(partial(compiled_func, a_, b_, c_), num_warmups=5, num_iterations=100)" + "benchmark(compiled_func, a_, b_, c_)" ] }, { @@ -394,65 +549,107 @@ "## TV Layout\n", "\n", "Both the naive and vectorized kernels follow a common pattern to map thread indices\n", - "to physical addresses:\n", + "to physical addresses in two steps:\n", "\n", - "Step 1: Map thread index to logical M/N coordinates\n", + "Step 1: Map thread index to logical coordinates in `(M, N)`\n", + "\n", + "* `mi = thread_idx // n`\n", + "* `ni = thread_idx % n`\n", + "\n", + "In native version, each thread process 1 element, in this case, `mi` and `ni` is logical\n", + "coordinate into data tensor `mA`, `mB` and `mC`.\n", + "\n", + "Int vectorized version, each thread process multiple values of input and output tensor.\n", + "logical coordinate should be computed with both thread and value index.\n", + "\n", + "* `thread_idx // n`\n", + "* `(thread_idx % n) * 4 + value_idx`\n", + "\n", + "\n", + "Step 2: Map logical coordinates in `(M, N)` to physical addresses using the tensor layout\n", + "\n", + "* Vectorized Load\n", "\n", "```python\n", - " mi = thread_idx // n\n", - " ni = thread_idx % n\n", + " frgA = gA[(None, (mi, ni))].load()\n", "```\n", "\n", - "Step 2: Map logical M/N coordinates to physical addresses using the tensor layout\n", + "* Elementwise Load (less efficient)\n", "\n", "```python\n", - " a[(None, (mi, ni))].load()\n", + " frgA0 = mA[(mi, ni * 4 + 0)]\n", + " frgA1 = mA[(mi, ni * 4 + 1)]\n", + " frgA2 = mA[(mi, ni * 4 + 2)]\n", + " frgA3 = mA[(mi, ni * 4 + 3)]\n", + "\n", + " # Or use divided layout\n", + "\n", + " frgA0 = gA[(0, (mi, ni))]\n", + " frgA1 = gA[(1, (mi, ni))]\n", + " frgA2 = gA[(2, (mi, ni))]\n", + " frgA3 = gA[(3, (mi, ni))]\n", "```\n", "\n", - "CuTe uses TV layout to represent this mapping from thread index and value index\n", + "CuTe introduces TV layout to represent this mapping from thread index and value index\n", "(i.e., the 4 elements loaded per thread) to the logical coordinate space of a tensor.\n", "By configuring different TV layouts, we can experiment with different memory access\n", "patterns with minimal code changes.\n", "\n", - "The following example demonstrates two levels of tiling: at the thread-block level\n", - "and at the thread level.\n", + "**Definition:** *TV Layout* is rank-2 layout which maps `(thread_index, value_index)` \n", + "to logical coordinate of tensor. \n", + "\n", + "We always have *TV Layout* with canonical form as `(thread_domain, value_domain):(..., ...)`.\n", + "\n", + "With *TV Layout*, each thread can find logical coordinates or indices of data partitioned\n", + "to current thread.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Elementwise with TV Layout\n", + "\n", + "In this example, we rewrite elementwise kernel with two levels of tiling: \n", + "* the thread-block level \n", + "* the thread level with TV Layout and tiling\n", "\n", "For thread-block level tiling, each input & output tensor is first divided\n", - "into a group of ``(TileM, TileN)`` sub-tensors at the host side.\n", + "into a group of ``(TileM, TileN)`` sub-tensors at the host side. Please be noticed that\n", + "in this case, we still use `zipped_divide` but for tiling at thread-block level.\n", "\n", - "Inside the GPU kernel, we provide the thread-block index to the 2nd mode of the tiled tensor\n", - "(``gA[((None, None), bidx)]``), which returns a thread-block local view of\n", - "a single ``(TileM, TileN)`` sub-tensor.\n", + "Inside the GPU kernel, we slice tiled tensor with the thread-block index at the 2nd mode \n", + "as ``gA[((None, None), bidx)]``, which returns a thread-block local view of\n", + "a single ``(TileM, TileN)`` sub-tensor. This sub-tensor maps logical coordinates\n", + "inside ``(TileM, TileN)`` to physical address of elements.\n", "\n", - "For thread level tiling, we compose the sub-tensor (which maps from logical coordinates\n", - "to physical addresses) with the TV layout (which maps from thread & value indices to\n", - "logical coordinates). This gives us a tiled sub-tensor that maps from thread & value\n", - "indices directly to physical addresses.\n", + "At thread level tiling, we compose the above sub-tensor (logical coordinates to physical addresses) \n", + "with the TV layout (thread & value indices to logical coordinates). This gives us a tiled sub-tensor \n", + "that maps from thread & value indices directly to physical addresses.\n", "\n", - "We then provide the thread index to the tiled sub-tensor (``tidfrgA[(tidx, None)]``)\n", - "to get a thread-local view of the data each thread accesses. Note that the thread index\n", - "is now in the 1st mode, as the tiled sub-tensor puts the thread mode before the value mode." + "We then slice it with the thread index as ``tidfrgA[(tidx, None)]`` to get \n", + "a thread-local view of the data each thread accesses. Note that the thread index\n", + "is now in the 1st mode, as TV layout is normally have form ``(thread_domain, value_domain):(...)``.\n", + "\n", + "### Kernel Code" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@cute.kernel\n", "def elementwise_add_kernel(\n", - " gA: cute.Tensor,\n", - " gB: cute.Tensor,\n", - " gC: cute.Tensor,\n", - " tv_layout: cute.Layout\n", + " gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor, tv_layout: cute.Layout\n", "):\n", " tidx, _, _ = cute.arch.thread_idx()\n", " bidx, _, _ = cute.arch.block_idx()\n", "\n", - " #--------------------------------\n", + " # --------------------------------\n", " # slice for thread-block level view\n", - " #--------------------------------\n", + " # --------------------------------\n", " blk_coord = ((None, None), bidx)\n", "\n", " # logical coord -> address\n", @@ -460,9 +657,9 @@ " blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n", " blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n", "\n", - " #--------------------------------\n", + " # --------------------------------\n", " # compose for thread-index & value-index to physical mapping\n", - " #--------------------------------\n", + " # --------------------------------\n", " # blockA: (TileM, TileN) -> physical address\n", " # tv_layout: (tid, vid) -> (TileM, TileN)\n", " # tidfrgA = blkA o tv_layout\n", @@ -471,14 +668,15 @@ " tidfrgB = cute.composition(blkB, tv_layout)\n", " tidfrgC = cute.composition(blkC, tv_layout)\n", "\n", - " print(f\"Composed with TV layout:\")\n", + " print(\"Composed with TV layout:\")\n", " print(f\" tidfrgA: {tidfrgA.type}\")\n", "\n", - " #--------------------------------\n", + " # --------------------------------\n", " # slice for thread-level view\n", - " #--------------------------------\n", + " # --------------------------------\n", " # `None` represent slice of the entire per-thread data\n", " thr_coord = (tidx, None)\n", + " # thr_coord = (tidx, cute.repeat_like(None, gA.shape[1]))\n", "\n", " # slice for threads: vid -> address\n", " thrA = tidfrgA[thr_coord] # (V) -> physical address\n", @@ -492,71 +690,29 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If we take a closer look at the layout of zipped divided input tensor `gA`:\n", - "\n", - "```\n", - "Tiled to Thread Block:\n", - "\n", - " ((16,256),(128,8)) : ((2048,1),(32768,256))\n", - " ~~~~~~~~ ~~~~~~ ~~~~~~~~\n", - " | | |\n", - " | | |\n", - " | `------------------------> Number of Thread Blocks\n", - " | |\n", - " | |\n", - " `--------------------'\n", - " |\n", - " V\n", - " Thread Block\n", - " Tile\n", - "\n", - "Sliced to Thread-Block local sub-tensor (a (16, 256) tile): gA[((None, None), bidx)]\n", - "\n", - " (16,256) : (2048,1)\n", - " ~~~~~~ ~~~~~~\n", - " | | Tiled/Composed with TV Layout\n", - " | | \n", - " | | o ((32,4),(8,4)):((128,4),(16,1))\n", - " V V \n", - "~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~ \n", - "((32,4), (8,4)) : ((4,8192),(1,2048))\n", - " | |\n", - " | `--------> per thread fragment\n", - " |\n", - "Thread Block\n", - " Shape\n", - "\n", - "Sliced to Thread local sub-tensor (a (4,8) tile): tidfrgA[(tidx, None)]\n", - "\n", - "```\n", + "### Host Code\n", "\n", "The host code below shows the construction of the TV layout. By composing\n", - "a thread layout of ``(4,32):(32,1)`` (32 threads read contiguous elements on the row dimension,\n", - "then 4 warps read different rows) with a value layout of ``(4,8):(8,1)`` (each thread reads\n", - "8 contiguous elements on the row dimension across 4 contiguous rows),\n", - "we obtain the TV layout shown in the figure above." + "a thread layout of ``(4,64):(64,1)`` (64 threads read contiguous elements on the row dimension,\n", + "then 64-thread-groups(2 warps) read different rows) with a value layout of ``(16,8):(8,1)`` (each thread reads\n", + "8 contiguous 16b elements on the row dimension across 4 contiguous rows).\n", + "\n", + "In order to generalize, we started with byte-layout to describe layout for elements in bytes. This is\n", + "to ensure use of 128bit vectorized load store. Then we leverage ``recast_layout`` to convert into\n", + "element-layout.\n", + "\n", + "```python\n", + " # src type bits: 8\n", + " # dst type bits: bits of element type\n", + " val_layout = cute.recast_layout(dtype.width, 8, bit_val_layout)\n", + "```" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tiler: (16, 256)\n", - "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n", - "Tiled Input Tensors:\n", - " gA: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - " gB: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - " gC: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - "Composed with TV layout:\n", - " tidfrgA: !cute.memref, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def elementwise_add(\n", @@ -565,34 +721,249 @@ " mC: cute.Tensor,\n", "):\n", " # mA layout: (M, N):(N, 1)\n", - " # TV layout map thread & value index to (16, 256) logical tile\n", + " # TV layout map thread & value index to (64, 512) logical tile\n", " # - contiguous thread index maps to mode-1 because input layout is contiguous on\n", " # mode-1 for coalesced load-store\n", - " # - each thread load 8 contiguous element each row and load 4 rows\n", - " thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n", - " val_layout = cute.make_layout((4, 8), stride=(8, 1))\n", + " # - each thread load contiguous 16 bytes each row and load 16 rows\n", + " coalesced_ldst_bytes = 16\n", + "\n", + " # Compile time validation: expect same element type for all input tensors\n", + " assert all(t.element_type == mA.element_type for t in [mA, mB, mC])\n", + " dtype = mA.element_type\n", + "\n", + " thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))\n", + " val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))\n", + " val_layout = cute.recast_layout(dtype.width, 8, val_layout)\n", " tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n", - " print(f\"Tiler: {tiler_mn}\")\n", - " print(f\"TV Layout: {tv_layout}\")\n", + "\n", + " print(f\"[DSL INFO] Tiler: {tiler_mn}\")\n", + " print(f\"[DSL INFO] TV Layout: {tv_layout}\")\n", "\n", " gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", " gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", " gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", "\n", - " print(f\"Tiled Input Tensors:\")\n", - " print(f\" gA: {gA.type}\")\n", - " print(f\" gB: {gB.type}\")\n", - " print(f\" gC: {gC.type}\")\n", + " print(\"Tiled Input Tensors:\")\n", + " print(\"[DSL INFO] Tiled Tensors:\")\n", + " print(f\"[DSL INFO] gA = {gA.type}\")\n", + " print(f\"[DSL INFO] gB = {gB.type}\")\n", + " print(f\"[DSL INFO] gC = {gC.type}\")\n", "\n", " # Launch the kernel asynchronously\n", " # Async token(s) can also be specified as dependencies\n", - " elementwise_add_kernel(\n", - " gA, gB, gC, tv_layout\n", - " ).launch(\n", + " elementwise_add_kernel(gA, gB, gC, tv_layout).launch(\n", " grid=[cute.size(gC, mode=[1]), 1, 1],\n", " block=[cute.size(tv_layout, mode=[0]), 1, 1],\n", " )\n", "\n", + "\n", + "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", + "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", + "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", + "\n", + "a_ = from_dlpack(a, assumed_align=16)\n", + "b_ = from_dlpack(b, assumed_align=16)\n", + "c_ = from_dlpack(c, assumed_align=16)\n", + "\n", + "elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n", + "elementwise_add_(a_, b_, c_)\n", + "\n", + "# verify correctness\n", + "torch.testing.assert_close(c, a + b)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Explanation of Layouts\n", + "\n", + "Let's take a closer look using zipped divided input tensor `gA` as an example.\n", + "We also choose a smaller M/N, `(256,512)`, to make it easier to explain and visualize.\n", + "\n", + "```\n", + "Tiled to Thread Block:\n", + "\n", + " ((16,256),(16,2)) : ((512,1),(8192,256))\n", + " ~~~~~~~~ ~~~~~~ ~~~~~\n", + " | | |\n", + " | | |\n", + " | `-----------------------> Number of Thread Blocks\n", + " | |\n", + " | |\n", + " `-------------------'\n", + " |\n", + " V\n", + " Thread Block\n", + " Tile\n", + "\n", + "Sliced to Thread-Block local sub-tensor (a (16, 256) tile): gA[((None, None), bidx)]\n", + "\n", + " (16,256) : (512,1)\n", + " ~~~~~~ ~~~~~~\n", + " | | Tiled/Composed with TV Layout\n", + " | |\n", + " | | o ((32,4),(8,4)):((128,4),(16,1))\n", + " V V\n", + "~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~\n", + "((32,4),(8,4)) : ((8,2048),(1,512))\n", + " | |\n", + " | `--------> per thread fragment\n", + " |\n", + "Thread Block\n", + " Shape\n", + "\n", + "Sliced to Thread local sub-tensor (a (4,8) tile): tidfrgA[(tidx, None)]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualization of TV Layout\n", + "\n", + "To visualize TV Layout, we can first install *`cute-viz`*\n", + "\n", + "```\n", + "pip install -U git+https://github.com/NTT123/cute-viz.git\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from cute_viz import display_tv_layout\n", + "\n", + " @cute.jit\n", + " def visualize():\n", + " # Create and render a layout to file\n", + " # layout = cute.make_layout( ((16,16),(256,2)), stride=((512,8192),(1,256)))\n", + " # display_layout(layout)\n", + "\n", + " tv_layout = cute.make_layout(((32, 4), (8, 4)), stride=((128, 4), (16, 1)))\n", + " display_tv_layout(tv_layout, (16, 256))\n", + "\n", + " thr_block_layout = cute.make_layout((16, 256), stride=(512, 1))\n", + " print(cute.composition(thr_block_layout, tv_layout))\n", + "\n", + " visualize()\n", + "except ImportError:\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***Why modes of thread domain of TV Layout looks swapped especially when tensor is row major?***\n", + "\n", + "We may notice that *TV Layout* in above example is `((32,4),(8,4)):((128,4),(16,1))`. \n", + "However, on visualization, thread indices are arrange as shape `(4,32)` rather than \n", + "`(32,4)` of *TV Layout*.\n", + "\n", + "This is a commonly asked question by developers from both internal teams and community.\n", + "\n", + "It's important to keep in mind that *TV Layout* maps `(thread_index, value_index)` to \n", + "`(row_index, column_index)` of logical domain `(TileM, TileN)`. However, visualization \n", + "shows **inverse** mapping of logical domain `(TileM, TileN)` to `(thread_domain, value_domain)`,\n", + "because this is more intuitive for human developer.\n", + "\n", + "That's why the shape of domain of *TV Layout* doesn't necessarily match logical view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "benchmark(elementwise_add_, a_, b_, c_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Remap/Transpose thread block index\n", + "\n", + "As tensors are row major in this example, we may want thread blocks to load contiguous memory as much as possible.\n", + "\n", + "We can apply a simple thread-block remapping to transpose the mapping of thread block indices in row first order. \n", + "`cute.composition(gA, (None, remap_block))` only apply transpose of 2nd mode of tiled layout but keep \n", + "the 1st mode un-touched.\n", + "\n", + "```python\n", + " remap_block = cute.make_ordered_layout(\n", + " cute.select(gA.shape[1], mode=[1, 0]), order=(1, 0)\n", + " )\n", + " gA = cute.composition(gA, (None, remap_block))\n", + " gB = cute.composition(gB, (None, remap_block))\n", + " gC = cute.composition(gC, (None, remap_block))\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@cute.jit\n", + "def elementwise_add(\n", + " mA: cute.Tensor,\n", + " mB: cute.Tensor,\n", + " mC: cute.Tensor,\n", + "):\n", + " # mA layout: (M, N):(N, 1)\n", + " # TV layout map thread & value index to (64, 512) logical tile\n", + " # - contiguous thread index maps to mode-1 because input layout is contiguous on\n", + " # mode-1 for coalesced load-store\n", + " # - each thread load contiguous 16 bytes each row and load 16 rows\n", + " coalesced_ldst_bytes = 16\n", + "\n", + " # Compile time validation: expect same element type for all input tensors\n", + " assert all(t.element_type == mA.element_type for t in [mA, mB, mC])\n", + " dtype = mA.element_type\n", + "\n", + " thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))\n", + " val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))\n", + " val_layout = cute.recast_layout(dtype.width, 8, val_layout)\n", + " tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n", + "\n", + " print(f\"[DSL INFO] Tiler: {tiler_mn}\")\n", + " print(f\"[DSL INFO] TV Layout: {tv_layout}\")\n", + "\n", + " gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", + " gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", + " gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", + "\n", + " # (RestM, RestN) -> (RestN, RestM)\n", + " remap_block = cute.make_ordered_layout(\n", + " cute.select(gA.shape[1], mode=[1, 0]), order=(1, 0)\n", + " )\n", + " gA = cute.composition(gA, (None, remap_block))\n", + " gB = cute.composition(gB, (None, remap_block))\n", + " gC = cute.composition(gC, (None, remap_block))\n", + "\n", + " print(\"Tiled Input Tensors:\")\n", + " print(\"[DSL INFO] Tiled Tensors:\")\n", + " print(f\"[DSL INFO] gA = {gA.type}\")\n", + " print(f\"[DSL INFO] gB = {gB.type}\")\n", + " print(f\"[DSL INFO] gC = {gC.type}\")\n", + "\n", + " # Launch the kernel asynchronously\n", + " # Async token(s) can also be specified as dependencies\n", + " elementwise_add_kernel(gA, gB, gC, tv_layout).launch(\n", + " grid=[cute.size(gC, mode=[1]), 1, 1],\n", + " block=[cute.size(tv_layout, mode=[0]), 1, 1],\n", + " )\n", + "\n", + "\n", "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", @@ -610,27 +981,18 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Average execution time: 0.0222 ms\n", - "Throughput: 1133.58 GB/s\n" - ] - } - ], + "outputs": [], "source": [ - "benchmark(partial(elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=200)" + "benchmark(compiled_func, a_, b_, c_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Using Lambda Function\n", + "## Using Lambda Function\n", "\n", "CuTe DSL is built on top of Python. It can leverage Python to implement meta-programming to generate flexible kernels.\n", "E.g. we can write kernel template that take custom binary operations to generate kernels for arbitrary binary operations.\n", @@ -640,9 +1002,8 @@ "@cute.jit\n", "def elementwise_apply(\n", " op: cutlass.Constexpr,\n", - " mA: cute.Tensor,\n", - " mB: cute.Tensor,\n", - " mC: cute.Tensor\n", + " inputs,\n", + " result: cute.Tensor\n", "):\n", " ...\n", "\n", @@ -651,110 +1012,142 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tiler: (16, 256)\n", - "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n", - "Tiled Input Tensors:\n", - " gA: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - " gB: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - " gC: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - "Composed with TV layout:\n", - " tidfrgA: !cute.memref, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n" - ] - } - ], + "outputs": [], "source": [ "@cute.kernel\n", "def elementwise_apply_kernel(\n", - " op: cutlass.Constexpr, # lambda function must be const expr to generate code at compile time\n", - " gA: cute.Tensor,\n", - " gB: cute.Tensor,\n", - " gC: cute.Tensor,\n", - " tv_layout: cute.Layout\n", + " op: cutlass.Constexpr,\n", + " mInputs: List[cute.Tensor],\n", + " mC: cute.Tensor,\n", + " cC: cute.Tensor, # coordinate tensor\n", + " shape: cute.Shape,\n", + " tv_layout: cute.Layout, # (tid, vid) -> logic coord\n", "):\n", " tidx, _, _ = cute.arch.thread_idx()\n", " bidx, _, _ = cute.arch.block_idx()\n", "\n", - " blk_coord = ((None, None), bidx)\n", + " ###############################################################################\n", + " # Slice to local tile of thread block\n", + " ###############################################################################\n", + " blk_crd = ((None, None), bidx)\n", "\n", - " # logical coord -> address\n", - " blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n", - " blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n", - " blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n", + " # Leverage the meta-programming capability of the DSL to slice the tensors for each input\n", + " # All for loops below on input tensors would be fully unrolled automatically at compile time\n", + " # logical coord -> memory address\n", + " gInputs = [t[blk_crd] for t in mInputs] # (TileM, TileN)\n", + " gC = mC[blk_crd] # (TileM, TileN)\n", + " gCrd = cC[blk_crd] # (TileM, TileN)\n", "\n", - " tidfrgA = cute.composition(blkA, tv_layout)\n", - " tidfrgB = cute.composition(blkB, tv_layout)\n", - " tidfrgC = cute.composition(blkC, tv_layout)\n", + " print(\"[DSL INFO] Sliced Tensors per thread block:\")\n", + " for i in cutlass.range_constexpr(len(gInputs)):\n", + " print(f\"[DSL INFO] ctaInputs{i} = {gInputs[i].type}\")\n", + " print(f\"[DSL INFO] gC = {gC.type}\")\n", + " print(f\"[DSL INFO] gCrd = {gCrd.type}\")\n", "\n", - " print(f\"Composed with TV layout:\")\n", - " print(f\" tidfrgA: {tidfrgA.type}\")\n", + " ###############################################################################\n", + " # Compose with thread block TV layout to map thread & value indices to memory address\n", + " ###############################################################################\n", + " # (tid, vid) -> memory address\n", + " tidfrgInputs = [cute.composition(t, tv_layout) for t in gInputs]\n", + " tidfrgC = cute.composition(gC, tv_layout)\n", + " tidfrgCrd = cute.composition(gCrd, tv_layout)\n", "\n", - " thr_coord = (tidx, None)\n", + " # repeat None like vid to remove hierarchy of layout\n", + " thr_crd = (tidx, cute.repeat_like(None, tidfrgInputs[0][1]))\n", "\n", - " # slice for threads: vid -> address\n", - " thrA = tidfrgA[thr_coord] # (V) -> physical address\n", - " thrB = tidfrgB[thr_coord] # (V) -> physical address\n", - " thrC = tidfrgC[thr_coord] # (V) -> physical address\n", + " ###############################################################################\n", + " # Slice to local tile of thread\n", + " ###############################################################################\n", + " # vid -> address\n", + " thrInputs = [t[thr_crd] for t in tidfrgInputs] # (V)\n", + " thrC = tidfrgC[thr_crd] # (V)\n", + " thrCrd = tidfrgCrd[thr_crd]\n", "\n", - " #--------------------------------\n", - " # apply custom operation\n", - " #--------------------------------\n", - " thrC[None] = op(thrA.load(), thrB.load())\n", + " print(\"[DSL INFO] Sliced Tensors per thread:\")\n", + " for i in cutlass.range_constexpr(len(thrInputs)):\n", + " print(f\"[DSL INFO] thrInputs{i} = {thrInputs[i].type}\")\n", + " print(f\"[DSL INFO] thrC = {thrC.type}\")\n", + " print(f\"[DSL INFO] thrCrd = {thrCrd.type}\")\n", + "\n", + " ###############################################################################\n", + " # Compute predicate for out of boundary checks\n", + " ###############################################################################\n", + " frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)\n", + " print(f\"[DSL INFO] frgPred = {frgPred.type}\")\n", + "\n", + " for i in cutlass.range_constexpr(cute.size(frgPred)):\n", + " frgPred[i] = cute.elem_less(thrCrd[i], shape)\n", + "\n", + " # if tidx == 0 and bidx == 0:\n", + " # cute.print_tensor(frgPred)\n", + "\n", + " ##########################################################\n", + " # Load data and compute result\n", + " ##########################################################\n", + "\n", + " # Load data before use. The compiler will optimize the copy and load\n", + " # operations to convert some memory ld/st into register uses.\n", + " result = op(*[thrInput.load() for thrInput in thrInputs])\n", + " thrC.store(result)\n", "\n", "\n", "@cute.jit\n", - "def elementwise_op(\n", - " op: cutlass.Constexpr,\n", - " mA: cute.Tensor,\n", - " mB: cute.Tensor,\n", - " mC: cute.Tensor,\n", - "):\n", - " # mA layout: (M, N):(N, 1)\n", - " # TV layout map thread & value index to (16, 256) logical tile\n", - " # - contiguous thread index maps to mode-1 because input layout is contiguous on\n", - " # mode-1 for coalesced load-store\n", - " # - each thread load 8 contiguous element each row and load 4 rows\n", - " thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n", - " val_layout = cute.make_layout((4, 8), stride=(8, 1))\n", + "def elementwise_apply(op: cutlass.Constexpr, inputs, result: cute.Tensor):\n", + " # Use 128bit(16B) load as canonicalized form of val_layout then recast to target element-type\n", + " coalesced_ldst_bytes = 16\n", + "\n", + " # Compile time validation: expect same element type for all input tensors\n", + " assert all(t.element_type == inputs[0].element_type for t in inputs)\n", + " dtype = inputs[0].element_type\n", + "\n", + " thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))\n", + " val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))\n", + " val_layout = cute.recast_layout(dtype.width, 8, val_layout)\n", " tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n", - " print(f\"Tiler: {tiler_mn}\")\n", - " print(f\"TV Layout: {tv_layout}\")\n", "\n", - " gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", - " gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", - " gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", + " mInputs = [cute.zipped_divide(input, tiler_mn) for input in inputs]\n", + " mC = cute.zipped_divide(result, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", "\n", - " print(f\"Tiled Input Tensors:\")\n", - " print(f\" gA: {gA.type}\")\n", - " print(f\" gB: {gB.type}\")\n", - " print(f\" gC: {gC.type}\")\n", + " # (RestM, RestN) -> (RestN, RestM)\n", + " remap_block = cute.make_ordered_layout(\n", + " cute.select(mInputs[0].shape[1], mode=[1, 0]), order=(1, 0)\n", + " )\n", + " for i, t in enumerate(mInputs):\n", + " mInputs[i] = cute.composition(t, (None, remap_block))\n", + "\n", + " mC = cute.composition(mC, (None, remap_block))\n", + "\n", + " idC = cute.make_identity_tensor(result.shape)\n", + " cC = cute.zipped_divide(idC, tiler=tiler_mn)\n", "\n", " # Launch the kernel asynchronously\n", - " # Async token(s) can also be specified as dependencies\n", - " elementwise_apply_kernel(\n", - " op, gA, gB, gC, tv_layout\n", - " ).launch(\n", - " grid=[cute.size(gC, mode=[1]), 1, 1],\n", + " # Group input tensors into a list as a single argument\n", + " elementwise_apply_kernel(op, mInputs, mC, cC, result.shape, tv_layout).launch(\n", + " grid=[cute.size(mC, mode=[1]), 1, 1],\n", " block=[cute.size(tv_layout, mode=[0]), 1, 1],\n", " )\n", "\n", + "\n", "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", "\n", "a_ = from_dlpack(a, assumed_align=16)\n", "b_ = from_dlpack(b, assumed_align=16)\n", - "c_ = from_dlpack(c, assumed_align=16)\n", - "\n", + "c_ = from_dlpack(c, assumed_align=16)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "from operator import mul\n", "\n", - "elementwise_op(mul, a_, b_, c_)\n", + "elementwise_apply(mul, [a_, b_], c_)\n", "\n", "# verify correctness\n", "torch.testing.assert_close(c, mul(a, b))" @@ -764,30 +1157,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "### Use customized function\n", + "\n", "Custom operators can be more complex. For example, here's a function that performs\n", "multiplication followed by ReLU:" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tiler: (16, 256)\n", - "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n", - "Tiled Input Tensors:\n", - " gA: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - " gB: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - " gC: !cute.memref, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", - "Composed with TV layout:\n", - " tidfrgA: !cute.memref, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n" - ] - } - ], + "outputs": [], "source": [ "def mul_relu(a, b):\n", " tmp = a * b\n", @@ -800,7 +1180,7 @@ " return torch.relu(tmp)\n", "\n", "\n", - "elementwise_op(mul_relu, a_, b_, c_)\n", + "elementwise_apply(mul_relu, [a_, b_], c_)\n", "\n", "# verify correctness\n", "torch.testing.assert_close(c, mul_relu_ref(a, b))" @@ -809,7 +1189,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv3_12", "language": "python", "name": "python3" }, @@ -823,7 +1203,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.10" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/examples/python/CuTeDSL/notebooks/hello_world.ipynb b/examples/python/CuTeDSL/notebooks/hello_world.ipynb index e722d828..218378ad 100644 --- a/examples/python/CuTeDSL/notebooks/hello_world.ipynb +++ b/examples/python/CuTeDSL/notebooks/hello_world.ipynb @@ -27,8 +27,8 @@ "metadata": {}, "outputs": [], "source": [ - "import cutlass \n", - "import cutlass.cute as cute " + "import cutlass\n", + "import cutlass.cute as cute" ] }, { @@ -80,14 +80,13 @@ "source": [ "@cute.jit\n", "def hello_world():\n", - "\n", " # Print hello world from host code\n", " cute.printf(\"hello world\")\n", "\n", " # Launch kernel\n", " kernel().launch(\n", - " grid=(1, 1, 1), # Single thread block\n", - " block=(32, 1, 1) # One warp (32 threads) per thread block\n", + " grid=(1, 1, 1), # Single thread block\n", + " block=(32, 1, 1), # One warp (32 threads) per thread block\n", " )" ] }, @@ -107,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -115,17 +114,19 @@ "output_type": "stream", "text": [ "Running hello_world()...\n", - "hello world\n", "Compiling...\n", + "hello world\n", "Hello world\n", + "Compiling with PTX/CUBIN dumped...\n", "Running compiled version...\n", - "hello world\n" + "hello world\n", + "Hello world\n" ] } ], "source": [ "# Initialize CUDA context for launching a kernel with error checking\n", - "# We make context initialization explicit to allow users to control the context creation \n", + "# We make context initialization explicit to allow users to control the context creation\n", "# and avoid potential issues with multiple contexts\n", "cutlass.cuda.initialize_cuda_context()\n", "\n", @@ -137,6 +138,14 @@ "print(\"Compiling...\")\n", "hello_world_compiled = cute.compile(hello_world)\n", "\n", + "# Dump PTX/CUBIN files while compiling\n", + "from cutlass.cute import KeepPTX, KeepCUBIN\n", + "\n", + "print(\"Compiling with PTX/CUBIN dumped...\")\n", + "# Alternatively, compile with string based options like\n", + "# cute.compile(hello_world, options=\"--keep-ptx --keep-cubin\") would also work.\n", + "hello_world_compiled_ptx_on = cute.compile[KeepPTX, KeepCUBIN](hello_world)\n", + "\n", "# Run the pre-compiled version\n", "print(\"Running compiled version...\")\n", "hello_world_compiled()" diff --git a/examples/python/CuTeDSL/notebooks/print.ipynb b/examples/python/CuTeDSL/notebooks/print.ipynb index ce1cf9ec..0e70d2c4 100644 --- a/examples/python/CuTeDSL/notebooks/print.ipynb +++ b/examples/python/CuTeDSL/notebooks/print.ipynb @@ -83,8 +83,8 @@ " print(\">>>\", type(b)) # => \n", "\n", " layout = cute.make_layout((a, b))\n", - " print(\">>>\", layout) # => (?,2):(1,?)\n", - " cute.printf(\">?? {}\", layout) # => (8,2):(1,8)" + " print(\">>>\", layout) # => (?,2):(1,?)\n", + " cute.printf(\">?? {}\", layout) # => (8,2):(1,8)" ] }, { @@ -221,6 +221,7 @@ " layout = cute.make_layout((a, b))\n", " print(f\"layout: {layout}\")\n", "\n", + "\n", "print(\"Direct run output:\")\n", "format_string_example(cutlass.Int32(8), 2)" ] @@ -246,23 +247,26 @@ "source": [ "from cutlass.cute.runtime import from_dlpack\n", "\n", + "\n", "@cute.jit\n", - "def print_tensor_basic(x : cute.Tensor):\n", + "def print_tensor_basic(x: cute.Tensor):\n", " # Print the tensor\n", " print(\"Basic output:\")\n", " cute.print_tensor(x)\n", - " \n", + "\n", + "\n", "@cute.jit\n", - "def print_tensor_verbose(x : cute.Tensor):\n", + "def print_tensor_verbose(x: cute.Tensor):\n", " # Print the tensor with verbose mode\n", " print(\"Verbose output:\")\n", " cute.print_tensor(x, verbose=True)\n", "\n", + "\n", "@cute.jit\n", - "def print_tensor_slice(x : cute.Tensor, coord : tuple):\n", + "def print_tensor_slice(x: cute.Tensor, coord: tuple):\n", " # slice a 2D tensor from the 3D tensor\n", " sliced_data = cute.slice_(x, coord)\n", - " y = cute.make_fragment(sliced_data.layout, sliced_data.element_type)\n", + " y = cute.make_rmem_tensor(sliced_data.layout, sliced_data.element_type)\n", " # Convert to TensorSSA format by loading the sliced data into the fragment\n", " y.store(sliced_data.load())\n", " print(\"Slice output:\")\n", @@ -302,12 +306,13 @@ "source": [ "def tensor_print_example1():\n", " shape = (4, 3, 2)\n", - " \n", + "\n", " # Creates [0,...,23] and reshape to (4, 3, 2)\n", - " data = np.arange(24, dtype=np.float32).reshape(*shape) \n", - " \n", + " data = np.arange(24, dtype=np.float32).reshape(*shape)\n", + "\n", " print_tensor_basic(from_dlpack(data))\n", "\n", + "\n", "tensor_print_example1()" ] }, @@ -348,12 +353,13 @@ "source": [ "def tensor_print_example2():\n", " shape = (4, 3)\n", - " \n", + "\n", " # Creates [0,...,11] and reshape to (4, 3)\n", - " data = np.arange(12, dtype=np.float32).reshape(*shape) \n", - " \n", + " data = np.arange(12, dtype=np.float32).reshape(*shape)\n", + "\n", " print_tensor_verbose(from_dlpack(data))\n", "\n", + "\n", "tensor_print_example2()" ] }, @@ -390,13 +396,14 @@ "source": [ "def tensor_print_example3():\n", " shape = (4, 3)\n", - " \n", + "\n", " # Creates [0,...,11] and reshape to (4, 3)\n", - " data = np.arange(12, dtype=np.float32).reshape(*shape) \n", - " \n", + " data = np.arange(12, dtype=np.float32).reshape(*shape)\n", + "\n", " print_tensor_slice(from_dlpack(data), (None, 0))\n", " print_tensor_slice(from_dlpack(data), (1, None))\n", "\n", + "\n", "tensor_print_example3()" ] }, @@ -418,9 +425,10 @@ " print(src)\n", " cute.print_tensor(src)\n", "\n", + "\n", "@cute.jit\n", "def print_tensor_host(src: cute.Tensor):\n", - " print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))" + " print_tensor_gpu(src).launch(grid=(1, 1, 1), block=(1, 1, 1))" ] }, { @@ -449,11 +457,14 @@ ], "source": [ "import torch\n", + "\n", + "\n", "def tensor_print_example4():\n", " a = torch.randn(4, 3, device=\"cuda\")\n", " cutlass.cuda.initialize_cuda_context()\n", " print_tensor_host(from_dlpack(a))\n", "\n", + "\n", "tensor_print_example4()" ] }, diff --git a/examples/python/CuTeDSL/notebooks/tensor.ipynb b/examples/python/CuTeDSL/notebooks/tensor.ipynb index ad5c7fc1..6106aae3 100644 --- a/examples/python/CuTeDSL/notebooks/tensor.ipynb +++ b/examples/python/CuTeDSL/notebooks/tensor.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -69,24 +69,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(raw_ptr(0x000000000736b0c0: f32, generic, align<4>) o (8,5):(5,1), data=\n", - " [[ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n", - " [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n", - " [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n", - " ...\n", - " [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n", - " [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n", - " [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ]])\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "\n", @@ -115,12 +100,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from cutlass.cute.runtime import from_dlpack\n", "\n", + "\n", "@cute.jit\n", "def print_tensor_dlpack(src: cute.Tensor):\n", " print(src)\n", @@ -129,25 +115,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor o (8,5):(5,1)>\n", - "tensor(raw_ptr(0x0000000007559340: f32, generic, align<4>) o (8,5):(5,1), data=\n", - " [[-1.151769, 1.019397, -0.371175, -0.717776, 0.502176, ],\n", - " [ 0.114282, 0.900084, 0.320770, 1.564574, -0.632329, ],\n", - " [-0.570140, 0.178112, -0.423079, 1.936198, 0.003355, ],\n", - " ...\n", - " [-2.425393, -0.275528, 1.267157, -0.811101, -0.985456, ],\n", - " [ 0.777889, -2.114074, 0.357184, -0.321312, -0.938138, ],\n", - " [ 1.959564, 1.797602, 0.116901, 0.306198, -1.837295, ]])\n" - ] - } - ], + "outputs": [], "source": [ "a = torch.randn(8, 5, dtype=torch_dtype(cutlass.Float32))\n", "\n", @@ -156,25 +126,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor o (8,8):(8,1)>\n", - "tensor(raw_ptr(0x0000000007979da0: f32, generic, align<4>) o (8,8):(8,1), data=\n", - " [[ 0.122739, -0.605744, -1.442022, ..., -0.356501, -0.993329, -0.091110, ],\n", - " [ 0.278448, 0.318482, -0.276867, ..., 1.542181, -1.701539, -0.309454, ],\n", - " [ 0.563565, -0.753936, 0.131214, ..., 0.437912, -0.482277, -0.051540, ],\n", - " ...\n", - " [-1.974096, -0.177881, 0.426807, ..., -1.579115, -0.304974, 0.451164, ],\n", - " [ 0.149851, -0.704689, -0.295063, ..., -0.653001, 0.008871, 0.903916, ],\n", - " [ 1.188619, 1.519662, 1.270734, ..., 0.404082, 0.173200, 0.093476, ]])\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np\n", "\n", @@ -211,39 +165,23 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "a[2] = 10.000000 (equivalent to a[(2,0)])\n", - "a[9] = 6.000000 (equivalent to a[(1,1)])\n", - "a[2,0] = 10.000000\n", - "a[2,4] = 14.000000\n", - "a[(2,4)] = 14.000000\n", - "a[2,3] = 100.000000\n", - "a[(2,4)] = 101.000000\n", - "tensor([[ 0., 1., 2., 3., 4.],\n", - " [ 5., 6., 7., 8., 9.],\n", - " [ 10., 11., 12., 100., 101.],\n", - " [ 15., 16., 17., 18., 19.],\n", - " [ 20., 21., 22., 23., 24.],\n", - " [ 25., 26., 27., 28., 29.],\n", - " [ 30., 31., 32., 33., 34.],\n", - " [ 35., 36., 37., 38., 39.]])\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def tensor_access_item(a: cute.Tensor):\n", " # access data using linear index\n", - " cute.printf(\"a[2] = {} (equivalent to a[{}])\", a[2],\n", - " cute.make_identity_tensor(a.layout.shape)[2])\n", - " cute.printf(\"a[9] = {} (equivalent to a[{}])\", a[9],\n", - " cute.make_identity_tensor(a.layout.shape)[9])\n", + " cute.printf(\n", + " \"a[2] = {} (equivalent to a[{}])\",\n", + " a[2],\n", + " cute.make_identity_tensor(a.layout.shape)[2],\n", + " )\n", + " cute.printf(\n", + " \"a[9] = {} (equivalent to a[{}])\",\n", + " a[9],\n", + " cute.make_identity_tensor(a.layout.shape)[9],\n", + " )\n", "\n", " # access data using n-d coordinates, following two are equivalent\n", " cute.printf(\"a[2,0] = {}\", a[2, 0])\n", @@ -251,14 +189,14 @@ " cute.printf(\"a[(2,4)] = {}\", a[2, 4])\n", "\n", " # assign value to tensor@(2,4)\n", - " a[2,3] = 100.0\n", - " a[2,4] = 101.0\n", - " cute.printf(\"a[2,3] = {}\", a[2,3])\n", - " cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n", + " a[2, 3] = 100.0\n", + " a[2, 4] = 101.0\n", + " cute.printf(\"a[2,3] = {}\", a[2, 3])\n", + " cute.printf(\"a[(2,4)] = {}\", a[(2, 4)])\n", "\n", "\n", "# Create a tensor with sequential data using torch\n", - "data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n", + "data = torch.arange(0, 8 * 5, dtype=torch.float32).reshape(8, 5)\n", "tensor_access_item(from_dlpack(data))\n", "\n", "print(data)" @@ -287,14 +225,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Coordinate Tensor\n", + "## Coordinate Tensors\n", "\n", - "A coordinate tensor is a special type of tensor that maps coordinates to coordinates rather than to values. \n", - "The key distinction is that while regular tensors map coordinates to some value type (like numbers), \n", - "coordinate tensors map coordinates to other coordinates.\n", + "### Definition and Properties\n", "\n", - "For example, given a shape (4,4), a coordinate tensor using row-major layout would appear as:\n", + "A coordinate tensor $T: Z^n → Z^m$ is a mathematical structure that establishes a mapping between coordinate spaces. Unlike standard tensors that map coordinates to scalar values, coordinate tensors map coordinates to other coordinates, forming a fundamental building block for tensor operations and transformations.\n", "\n", + "### Examples\n", + "\n", + "Consider a `(4,4)` coordinate tensor:\n", + "\n", + "**Row-Major Layout (C-style):**\n", "\\begin{bmatrix} \n", "(0,0) & (0,1) & (0,2) & (0,3) \\\\\n", "(1,0) & (1,1) & (1,2) & (1,3) \\\\\n", @@ -302,8 +243,7 @@ "(3,0) & (3,1) & (3,2) & (3,3)\n", "\\end{bmatrix}\n", "\n", - "The same shape with a column-major layout would appear as:\n", - "\n", + "**Column-Major Layout (Fortran-style):**\n", "\\begin{bmatrix}\n", "(0,0) & (1,0) & (2,0) & (3,0) \\\\\n", "(0,1) & (1,1) & (2,1) & (3,1) \\\\\n", @@ -311,40 +251,50 @@ "(0,3) & (1,3) & (2,3) & (3,3)\n", "\\end{bmatrix}\n", "\n", - "The key points about coordinate tensors are:\n", - "- Each element in the tensor is itself a coordinate tuple (i,j) rather than a scalar value\n", - "- The coordinates map to themselves - so position (1,2) contains the coordinate (1,2)\n", - "- The layout (row-major vs column-major) determines how these coordinate tuples are arranged in memory\n", + "### Identity Tensor\n", "\n", - "For example, coordinate tensors can be created using the `make_identity_tensor` utility:\n", + "An identity tensor $I$ is a special case of a coordinate tensor that implements the identity mapping function:\n", "\n", + "**Definition:**\n", + "For a given shape $S = (s_1, s_2, ..., s_n)$, the identity tensor $I$ satisfies: $I(c) = c, \\forall c \\in \\prod_{i=1}^n [0, s_i)$\n", + "\n", + "**Properties:**\n", + "1. **Bijective Mapping**: The identity tensor establishes a one-to-one correspondence between coordinates.\n", + "2. **Layout Invariance**: The logical structure remains constant regardless of the underlying memory layout.\n", + "3. **Coordinate Preservation**: For any coordinate c, I(c) = c.\n", + "\n", + "\n", + "CuTe establishes an isomorphism between 1-D indices and N-D coordinates through lexicographical ordering. For a coordinate c = (c₁, c₂, ..., cₙ) in an identity tensor with shape S = (s₁, s₂, ..., sₙ):\n", + "\n", + "**Linear Index Formula:**\n", + "$\\text{idx} = c_1 + \\sum_{i=2}^{n} \\left(c_i \\prod_{j=1}^{i-1} s_j\\right)$\n", + "\n", + "**Example:**\n", "```python\n", + "# Create an identity tensor from a given shape\n", "coord_tensor = make_identity_tensor(layout.shape())\n", + "\n", + "# Access coordinate using linear index\n", + "coord = coord_tensor[linear_idx] # Returns the N-D coordinate\n", "```\n", "\n", - "This creates a tensor that maps each coordinate to itself, providing a reference point for understanding how other layouts transform these coordinates." + "This bidirectional mapping enables efficient conversion from linear indices to N-dimensional coordinates, facilitating tensor operations and memory access patterns." ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor<(0,0) o (8,4):(1@0,1@1)>\n" - ] - } - ], + "outputs": [], "source": [ "@cute.jit\n", "def print_tensor_coord(a: cute.Tensor):\n", " coord_tensor = cute.make_identity_tensor(a.layout.shape)\n", " print(coord_tensor)\n", + " cute.print_tensor(coord_tensor)\n", "\n", - "a = torch.randn(8,4, dtype=torch_dtype(cutlass.Float32))\n", + "\n", + "a = torch.randn(8, 4, dtype=torch_dtype(cutlass.Float32))\n", "print_tensor_coord(from_dlpack(a))" ] } diff --git a/examples/python/CuTeDSL/notebooks/tensorssa.ipynb b/examples/python/CuTeDSL/notebooks/tensorssa.ipynb index 40a56d97..f60e0365 100644 --- a/examples/python/CuTeDSL/notebooks/tensorssa.ipynb +++ b/examples/python/CuTeDSL/notebooks/tensorssa.ipynb @@ -10,8 +10,7 @@ "import cutlass.cute as cute\n", "from cutlass.cute.runtime import from_dlpack\n", "\n", - "import numpy as np\n", - "import torch" + "import numpy as np" ] }, { @@ -55,12 +54,13 @@ " :param b: The source tensor to be loaded.\n", " \"\"\"\n", " a_vec = a.load()\n", - " print(f\"a_vec: {a_vec}\") # prints `a_vec: vector<12xf32> o (3, 4)`\n", + " print(f\"a_vec: {a_vec}\") # prints `a_vec: vector<12xf32> o (3, 4)`\n", " b_vec = b.load()\n", - " print(f\"b_vec: {b_vec}\") # prints `b_vec: vector<12xf32> o (3, 4)`\n", + " print(f\"b_vec: {b_vec}\") # prints `b_vec: vector<12xf32> o (3, 4)`\n", " res.store(a_vec + b_vec)\n", " cute.print_tensor(res)\n", "\n", + "\n", "a = np.ones(12).reshape((3, 4)).astype(np.float32)\n", "b = np.ones(12).reshape((3, 4)).astype(np.float32)\n", "c = np.zeros(12).reshape((3, 4)).astype(np.float32)\n", @@ -101,6 +101,7 @@ " dst[0] = dst_vec\n", " cute.print_tensor(dst)\n", "\n", + "\n", "def slice_1():\n", " src_shape = (4, 2, 3)\n", " dst_shape = (4, 3)\n", @@ -124,6 +125,7 @@ " dst = np.random.randn(*dst_shape).astype(np.float32)\n", " apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n", "\n", + "\n", "slice_1()" ] }, @@ -141,6 +143,7 @@ " dst = np.random.randn(*dst_shape).astype(np.float32)\n", " apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n", "\n", + "\n", "slice_2()" ] }, @@ -169,22 +172,22 @@ " b_vec = b.load()\n", "\n", " add_res = a_vec + b_vec\n", - " cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n", + " cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n", "\n", " sub_res = a_vec - b_vec\n", - " cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n", + " cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n", "\n", " mul_res = a_vec * b_vec\n", - " cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n", + " cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n", "\n", " div_res = a_vec / b_vec\n", - " cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n", + " cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n", "\n", " floor_div_res = a_vec // b_vec\n", - " cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n", + " cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n", "\n", " mod_res = a_vec % b_vec\n", - " cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n", + " cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n", "\n", "\n", "a = np.empty((3,), dtype=np.float32)\n", @@ -206,22 +209,23 @@ " a_vec = a.load()\n", "\n", " add_res = a_vec + c\n", - " cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n", + " cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n", "\n", " sub_res = a_vec - c\n", - " cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n", + " cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n", "\n", " mul_res = a_vec * c\n", - " cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n", + " cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n", "\n", " div_res = a_vec / c\n", - " cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n", + " cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n", "\n", " floor_div_res = a_vec // c\n", - " cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]\n", + " cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]\n", "\n", " mod_res = a_vec % c\n", - " cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n", + " cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n", + "\n", "\n", "a = np.empty((3,), dtype=np.float32)\n", "a.fill(1.0)\n", @@ -251,11 +255,12 @@ " eq_res = a_ == b_ # [False, False, False]\n", " \"\"\"\n", "\n", + "\n", "a = np.array([1, 2, 3], dtype=np.float32)\n", "b = np.array([2, 1, 4], dtype=np.float32)\n", "res = np.empty((3,), dtype=np.bool_)\n", "binary_op_3(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n", - "print(res) # prints [False, True, False]\n" + "print(res) # prints [False, True, False]" ] }, { @@ -278,11 +283,12 @@ " # and_res = a_vec & b_vec\n", " # res.store(and_res) # prints [0, 2, 0]\n", "\n", + "\n", "a = np.array([1, 2, 3], dtype=np.int32)\n", "b = np.array([2, 2, 4], dtype=np.int32)\n", "res = np.empty((3,), dtype=np.int32)\n", "binary_op_4(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n", - "print(res) # prints [3, 0, 7]" + "print(res) # prints [3, 0, 7]" ] }, { @@ -303,14 +309,15 @@ " a_vec = a.load()\n", "\n", " sqrt_res = cute.math.sqrt(a_vec)\n", - " cute.print_tensor(sqrt_res) # prints [2.000000, 2.000000, 2.000000]\n", + " cute.print_tensor(sqrt_res) # prints [2.000000, 2.000000, 2.000000]\n", "\n", " sin_res = cute.math.sin(a_vec)\n", " res.store(sin_res)\n", - " cute.print_tensor(sin_res) # prints [-0.756802, -0.756802, -0.756802]\n", + " cute.print_tensor(sin_res) # prints [-0.756802, -0.756802, -0.756802]\n", "\n", " exp2_res = cute.math.exp2(a_vec)\n", - " cute.print_tensor(exp2_res) # prints [16.000000, 16.000000, 16.000000]\n", + " cute.print_tensor(exp2_res) # prints [16.000000, 16.000000, 16.000000]\n", + "\n", "\n", "a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n", "res = np.empty((3,), dtype=np.float32)\n", @@ -344,26 +351,14 @@ " :param src: The source tensor to be reduced.\n", " \"\"\"\n", " a_vec = a.load()\n", - " red_res = a_vec.reduce(\n", - " cute.ReductionOp.ADD,\n", - " 0.0,\n", - " reduction_profile=0\n", - " )\n", - " cute.printf(red_res) # prints 21.000000\n", + " red_res = a_vec.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=0)\n", + " cute.printf(red_res) # prints 21.000000\n", "\n", - " red_res = a_vec.reduce(\n", - " cute.ReductionOp.ADD,\n", - " 0.0,\n", - " reduction_profile=(None, 1)\n", - " )\n", - " cute.print_tensor(red_res) # prints [6.000000, 15.000000]\n", + " red_res = a_vec.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1))\n", + " cute.print_tensor(red_res) # prints [6.000000, 15.000000]\n", "\n", - " red_res = a_vec.reduce(\n", - " cute.ReductionOp.ADD,\n", - " 1.0,\n", - " reduction_profile=(1, None)\n", - " )\n", - " cute.print_tensor(red_res) # prints [6.000000, 8.000000, 10.000000]\n", + " red_res = a_vec.reduce(cute.ReductionOp.ADD, 1.0, reduction_profile=(1, None))\n", + " cute.print_tensor(red_res) # prints [6.000000, 8.000000, 10.000000]\n", "\n", "\n", "a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n", @@ -399,7 +394,7 @@ "\n", "@cute.jit\n", "def broadcast_examples():\n", - " a = cute.make_fragment((1,3), dtype=cutlass.Float32)\n", + " a = cute.make_rmem_tensor((1, 3), dtype=cutlass.Float32)\n", " a[0] = 0.0\n", " a[1] = 1.0\n", " a[2] = 2.0\n", @@ -411,7 +406,7 @@ " # [ 0.000000, 1.000000, 2.000000, ],\n", " # [ 0.000000, 1.000000, 2.000000, ]])\n", "\n", - " c = cute.make_fragment((4,1), dtype=cutlass.Float32)\n", + " c = cute.make_rmem_tensor((4, 1), dtype=cutlass.Float32)\n", " c[0] = 0.0\n", " c[1] = 1.0\n", " c[2] = 2.0\n", @@ -494,7 +489,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.10" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/examples/python/CuTeDSL/utils/__init__.py b/examples/python/CuTeDSL/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/python/CuTeDSL/utils/fmha_helpers.py b/examples/python/CuTeDSL/utils/fmha_helpers.py new file mode 100644 index 00000000..6d49a9f1 --- /dev/null +++ b/examples/python/CuTeDSL/utils/fmha_helpers.py @@ -0,0 +1,975 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from typing import Tuple, Optional +import cutlass +from cutlass.cute.typing import Boolean + +from cutlass.cutlass_dsl import ( + Int32, + Float32, + min, + extract_mlir_values, + new_from_mlir_values, +) +from cutlass.utils.hardware_info import HardwareInfo +from cutlass.utils import WorkTileInfo +import cutlass.cute as cute + +############################################################################## +# Fmha static tile scheduler +############################################################################## + + +class FmhaStaticTileSchedulerParams: + """A class to represent parameters for the FMHA (Fused Multi-Head Attention) static tile scheduler. + + This class holds the configuration parameters needed to initialize and configure + the tile scheduler for FMHA operations. + + :ivar is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + :ivar problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + """ + + def __init__( + self, + is_persistent: bool, + problem_shape_mbh: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the FmhaStaticTileSchedulerParams with the given parameters. + + :param is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + :param problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + """ + self.is_persistent = is_persistent + self.problem_shape_mbh = problem_shape_mbh + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_mbh]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.problem_shape_mbh], self._values_pos): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FmhaStaticTileSchedulerParams( + self.is_persistent, *(tuple(obj_list)), loc=self._loc + ) + + +class FmhaStaticTileScheduler: + """A static tile scheduler for FMHA (Fused Multi-Head Attention) operations. + + This class manages the scheduling of work tiles for FMHA kernels, supporting + both persistent and non-persistent kernel modes. It tracks the current work + position and advances through the problem space efficiently. + + :ivar _params: Scheduler parameters. + :type _params: FmhaStaticTileSchedulerParams + :ivar _blk_coord: Block coordinates. + :type _blk_coord: cute.Coord + :ivar _grid_shape: Grid shape for the kernel. + :type _grid_shape: cute.Shape + :ivar _is_persistent: Whether to use persistent kernel mode. + :type _is_persistent: bool + :ivar _current_work_linear_idx: Current linear work index. + :type _current_work_linear_idx: Int32 + :ivar _problem_shape_mbh: Problem shape in (M, B, H) format. + :type _problem_shape_mbh: cute.Layout + :ivar _num_blocks: Number of blocks in the problem. + :type _num_blocks: Int32 + :ivar _is_first_block: Whether this is the first block. + :type _is_first_block: bool + :ivar num_persistent_sm: Number of persistent SMs. + :type num_persistent_sm: Int32 + """ + + def __init__( + self, + params: FmhaStaticTileSchedulerParams, + current_work_linear_idx: Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the FmhaStaticTileScheduler with the given parameters. + + :param params: Scheduler parameters. + :type params: FmhaStaticTileSchedulerParams + :param current_work_linear_idx: Current linear work index. + :type current_work_linear_idx: Int32 + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param grid_shape: Grid shape for the kernel. + :type grid_shape: cute.Shape + """ + self._params = params + self._blk_coord = blk_coord + self._grid_shape = grid_shape + self._is_persistent = params.is_persistent + self._current_work_linear_idx = current_work_linear_idx + self._problem_shape_mbh = cute.make_layout( + params.problem_shape_mbh, loc=loc, ip=ip + ) + self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) + self._is_first_block = True + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + self._loc = loc + self._ip = ip + + # called by host + @staticmethod + def get_grid_shape( + params: FmhaStaticTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> cute.Shape: + """ + Determine the grid shape for the FMHA kernel. + + For persistent kernels, the grid shape is limited by the number of SMs + (Streaming Multiprocessors) available on the device. For non-persistent + kernels, the grid shape matches the problem shape. + + :param params: Scheduler parameters. + :type params: FmhaStaticTileSchedulerParams + + :return: Grid shape as (M, B, H) tuple. + :rtype: cute.Shape + """ + if params.is_persistent: + hardware_info = HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return ( + min(sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip)), + 1, + 1, + ) + else: + return params.problem_shape_mbh + + @staticmethod + def check_valid_work_for_seqlen_q( + q_tiler: int, + current_idx: Int32, + seqlen_q: Int32, + ) -> Boolean: + """ + Check if the current work index is valid for the given query sequence length. + + This method verifies that the current work tile index multiplied by the + query tiler size is within the bounds of the query sequence length. + + :param q_tiler: Query tiler size. + :type q_tiler: int + :param current_idx: Current work index. + :type current_idx: Int32 + :param seqlen_q: Query sequence length. + :type seqlen_q: Int32 + + :return: True if the work is valid, False otherwise. + :rtype: Boolean + """ + return current_idx * q_tiler < seqlen_q + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + """ + Get information about the current work tile. + + Determines if the current work is valid and computes the tile coordinates + based on whether the kernel is persistent or non-persistent. + + :return: WorkTileInfo containing tile coordinates and validity flag. + :rtype: WorkTileInfo + """ + is_valid = ( + self._current_work_linear_idx < self._num_blocks + if self._is_persistent + else self._is_first_block + ) + + blk_coord = (0, 0, 0) + if self._is_persistent: + blk_coord = self._problem_shape_mbh.get_hier_coord( + self._current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = self._blk_coord + + # cur_tile_coord is (mid, 0, (bid, hid)) + cur_tile_coord = ( + blk_coord[0], + 0, + (blk_coord[1], blk_coord[2]), + ) + + return WorkTileInfo(cur_tile_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + """ + Get the initial work tile information. + + :return: Initial WorkTileInfo. + :rtype: WorkTileInfo + """ + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + """ + Advance to the next work tile. + + For persistent kernels, advances by the number of persistent SMs. + For non-persistent kernels, marks that the first block has been processed. + + :param advance_count: Number of steps to advance (default: 1). + :type advance_count: int + """ + if self._is_persistent: + self._current_work_linear_idx += advance_count * self.num_persistent_sm + self._is_first_block = False + + def __extract_mlir_values__(self): + values = extract_mlir_values(self._params) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self._blk_coord)) + values.extend(extract_mlir_values(self._grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 10 + new_params = new_from_mlir_values(self._params, values[0:3]) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[3]] + ) + new_blk_coord = new_from_mlir_values(self._blk_coord, values[4:7]) + new_grid_shape = new_from_mlir_values(self._grid_shape, values[7:]) + return FmhaStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_fmha_static_tile_scheduler( + params: FmhaStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> FmhaStaticTileScheduler: + """ + Create a new FMHA static tile scheduler. + + :param params: Scheduler parameters. + :type params: FmhaStaticTileSchedulerParams + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param grid_shape: Grid shape. + :type grid_shape: cute.Shape + + :return: New FmhaStaticTileScheduler instance. + :rtype: FmhaStaticTileScheduler + """ + return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) + + +def create_fmha_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_mbh: cute.Shape, +) -> FmhaStaticTileSchedulerParams: + """ + Create FMHA static tile scheduler parameters. + + :param is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + :param problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + + :return: New FmhaStaticTileSchedulerParams instance. + :rtype: FmhaStaticTileSchedulerParams + """ + return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) + + +def compute_grid( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + is_persistent: bool, +) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + """ + Compute grid parameters for FMHA operation. + + This function calculates the appropriate grid shape and scheduler parameters + based on the output tensor shape, CTA (Cooperative Thread Array) tiler, + and whether to use persistent kernel mode. + + The output tensor o has shape (s, d, ((h_r, h_k), b)) where: + - s: sequence length + - d: head dimension + - h_r: number of heads for query + - h_k: number of heads for key + - b: batch size + + :param o_shape: Output tensor shape for grid computation. + :type o_shape: cute.Shape + :param cta_tiler: CTA tiler dimensions (M, N, K). + :type cta_tiler: Tuple[int, int, int] + :param is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + + :return: Tuple of (scheduler_params, grid_shape). + :rtype: Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]] + """ + tile_sched_params = create_fmha_static_tile_scheduler_params( + is_persistent, + ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ), + ) + grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + + return tile_sched_params, grid + + +############################################################################## +# Fused Mask +############################################################################## + + +class MaskEnum(enum.Enum): + """Enumeration of mask types for FMHA operations. + + - RESIDUAL_MASK: Residual mask for handling variable sequence lengths + - WINDOW_MASK: Window mask for attention which also includes causal and no mask + - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation that the end of q is aligned with the end of k + - WINDOW_MASK_BWD: Window mask for backward pass + - WINDOW_MASK_BWD_INFERENCE: Same as the window mask for backward pass, but has the limitation that the end of q is aligned with the end of k + """ + + RESIDUAL_MASK = enum.auto() + RESIDUAL_MASK_BWD = enum.auto() + WINDOW_MASK = enum.auto() + WINDOW_MASK_INFERENCE = enum.auto() + WINDOW_MASK_BWD = enum.auto() + WINDOW_MASK_BWD_INFERENCE = enum.auto() + + +class FusedMask: + """A fused mask implementation for FMHA operations. + + This class handles different types of attention masks including no mask, + residual mask for variable sequence lengths, and causal mask for + autoregressive attention patterns. + + The class provides methods to: + - Calculate trip counts for different mask types + - Apply masks to attention scores + - Handle masked and unmasked trip calculations + """ + + def get_trip_count( + mask_type: MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of trips needed for the current block. + + The trip count depends on the mask type and the block coordinates. + For causal masks, it considers the autoregressive constraint. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of trips needed. + :rtype: Int32 + """ + result = 0 + offset = 0 + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + if cutlass.const_expr(mask_type == MaskEnum.RESIDUAL_MASK): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK_BWD): + result = cute.ceil_div(seqlen_q, tile_shape[0]) + if cutlass.const_expr( + mask_type == MaskEnum.WINDOW_MASK + or mask_type == MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_right is None): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + else: + max_idx_q = (blk_coord[0] + 1) * tile_shape[0] + idx_k = max_idx_q + offset + window_size_right + tmp_blocks_k = cute.ceil_div(idx_k, tile_shape[1]) + max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) + result = min(max_blocks_k, tmp_blocks_k) + if cutlass.const_expr( + mask_type == MaskEnum.WINDOW_MASK_BWD + or mask_type == MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_left is None): + result = cute.ceil_div(seqlen_q, tile_shape[0]) + else: + max_idx_k = (blk_coord[1] + 1) * tile_shape[1] + idx_k = max_idx_k + offset + window_size_left + tmp_blocks_q = cute.ceil_div(idx_k, tile_shape[0]) + max_blocks_q = cute.ceil_div(seqlen_q, tile_shape[0]) + result = min(max_blocks_q, tmp_blocks_q) + start_block = FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + result = result - start_block + return result + + @cute.jit + def get_trip_start( + mask_type: MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Get the start of the trip for the current block. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + """ + result = 0 + offset = 0 + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + if cutlass.const_expr( + mask_type is MaskEnum.WINDOW_MASK + or mask_type is MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_left is not None): + min_idx_q = blk_coord[0] * tile_shape[0] + idx_k = min_idx_q + offset - window_size_left + tmp_blocks_k = idx_k // tile_shape[1] + result = max(tmp_blocks_k, result) + if cutlass.const_expr( + mask_type is MaskEnum.WINDOW_MASK_BWD + or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_k = blk_coord[1] * tile_shape[1] + idx_q = min_idx_k + offset - window_size_right + tmp_blocks_q = idx_q // tile_shape[0] + result = max(tmp_blocks_q, result) + return result + + @cute.jit + def get_leading_mask_id( + mask_type: MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Int32, Int32]: + """ + Get the begin and end tile idx for the leading mask. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Tuple of (begin, end) tile idx for the leading mask. + :rtype: Tuple[Int32, Int32] + """ + offset = 0 + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + leading_mask_begin = FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + trip_count = FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + leading_mask_end = leading_mask_begin + if cutlass.const_expr( + mask_type is MaskEnum.WINDOW_MASK + or mask_type is MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_left is not None): + min_idx_q = ( + (blk_coord[0] + 1) * tile_shape[0] + offset - window_size_left + ) + leading_mask_end = min( + cute.ceil_div(min_idx_q, tile_shape[1]) - 1, + trip_count + leading_mask_begin - 1, + ) + else: + leading_mask_end = leading_mask_begin - 1 + elif cutlass.const_expr( + mask_type is MaskEnum.WINDOW_MASK_BWD + or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_k = ( + (blk_coord[1] + 1) * tile_shape[1] + offset - window_size_right + ) + leading_mask_end = cute.ceil_div(min_idx_k, tile_shape[0]) - 1 + else: + leading_mask_end = leading_mask_begin - 1 + return leading_mask_begin, leading_mask_end + + @cute.jit + def get_trailing_mask_id( + mask_type: MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Optional[Int32], Optional[Int32]]: + """ + Get the begin and end tile idx for the trailing mask. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Tuple of (begin, end) tile idx for the trailing mask. + :rtype: Tuple[Int32, Int32] + """ + offset = 0 + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + trip_start = FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + trip_count = FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + trailing_mask_begin, trailing_mask_end = None, None + if cutlass.const_expr( + mask_type is MaskEnum.WINDOW_MASK + or mask_type is MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_q = blk_coord[0] * tile_shape[0] + offset + window_size_right + trailing_mask_begin = min( + min_idx_q // tile_shape[1], trip_count + trip_start - 1 + ) + trailing_mask_end = trip_count + trip_start - 1 + else: + # last tile, we always apply mask on it regardless whether it's a residual tile + trailing_mask_begin = trip_count + trip_start - 1 + trailing_mask_end = trip_count + trip_start - 1 + else: + if cutlass.const_expr(window_size_left is not None): + min_idx_k = blk_coord[1] * tile_shape[1] + offset + window_size_left + 1 + max_idx_k = ( + (blk_coord[1] + 1) * tile_shape[1] + offset + window_size_left + ) + trailing_mask_begin = min( + cute.ceil_div(min_idx_k, tile_shape[0]) - 1, + trip_count + trip_start - 1, + ) + trailing_mask_end = min( + cute.ceil_div(max_idx_k, tile_shape[0]) - 1, + trip_count + trip_start - 1, + ) + else: + # last tile, we always apply mask on it regardless whether it's a residual tile + trailing_mask_begin = trip_count + trip_start - 1 + trailing_mask_end = trip_count + trip_start - 1 + + return trailing_mask_begin, trailing_mask_end + + @cute.jit + def get_masked_leading_count( + mask_type: MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of masked trips for the leading mask. + + This is used for blocks that need special handling due to masking. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of masked trips. + :rtype: Int32 + """ + result = 0 + if cutlass.const_expr( + mask_type is not MaskEnum.RESIDUAL_MASK + and mask_type is not MaskEnum.RESIDUAL_MASK_BWD + ): + if cutlass.const_expr( + window_size_left is not None or window_size_right is not None + ): + leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + result = max(leading_mask_end - leading_mask_begin + 1, 0) + + return result + + @cute.jit + def get_masked_trailing_count( + mask_type: MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + rem_count: Optional[Int32] = 0, + ) -> Int32: + """ + Calculate the number of masked trips for the trailing mask. + + This is used for blocks that need special handling due to masking. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + :param rem_count: Remaining count from previous calculations. + :type rem_count: Int32 + + :return: Number of masked trips. + :rtype: Int32 + """ + result = 0 + + if cutlass.const_expr( + mask_type is not MaskEnum.RESIDUAL_MASK + and mask_type is not MaskEnum.RESIDUAL_MASK_BWD + ): + if cutlass.const_expr( + window_size_left is not None or window_size_right is not None + ): + trailing_mask_begin, trailing_mask_end = FusedMask.get_trailing_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + if cutlass.const_expr( + trailing_mask_begin is not None and trailing_mask_end is not None + ): + if trailing_mask_begin <= leading_mask_end: + result = max(trailing_mask_end - leading_mask_end, 0) + else: + result = max(trailing_mask_end - trailing_mask_begin + 1, 0) + else: + if seqlen_k % tile_shape[1] != 0: + result = 1 + else: + result = 0 + + return result + rem_count + + @cute.jit + def get_unmasked_trip_count( + mask_type: MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of unmasked trips for the current block. + + This represents the number of trips that don't require special + masking treatment. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of unmasked trips. + :rtype: Int32 + """ + result = ( + FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + - FusedMask.get_masked_leading_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + - FusedMask.get_masked_trailing_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + 0, + ) + ) + return result + + @cute.jit + def apply_mask( + mask_type: MaskEnum, + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + index_transform: cutlass.Constexpr = lambda index_q, index_k: ( + index_q, + index_k, + ), + ): + """ + Apply the appropriate mask to the attention scores. + + This method modifies the attention scores (acc_qk) based on the mask type + and the positions in the index tensor. + + :param mask_type: Type of mask to use + :type mask_type: utils.MaskEnum + :param acc_qk: Accumulated QK attention scores tensor. + :type acc_qk: cute.Tensor + :param index_qk: Index tensor containing position information. + :type index_qk: cute.Tensor + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Optional[int] + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[int] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[int] + """ + + tidx, tidy, tidx = cute.arch.thread_idx() + offset = 0 + offset = ( + seqlen_k - seqlen_q + if cutlass.const_expr( + mask_type is MaskEnum.WINDOW_MASK_INFERENCE + or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE + ) + else 0 + ) + for i in cutlass.range_constexpr(cute.size(acc_qk)): + index_q, index_k = index_transform(*index_qk[i]) + if cutlass.const_expr( + window_size_left is not None or window_size_right is not None + ): + if cutlass.const_expr(window_size_left is None): + if index_q + offset + window_size_right < index_k: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + elif cutlass.const_expr(window_size_right is None): + if index_q + offset - window_size_left > index_k: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + else: + max_K_index = min(index_q + offset + window_size_right, seqlen_k) + min_K_index = max(0, index_q + offset - window_size_left) + if index_k > max_K_index or index_k < min_K_index: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + + if cutlass.const_expr( + mask_type == MaskEnum.RESIDUAL_MASK + or mask_type == MaskEnum.RESIDUAL_MASK_BWD + ): + if index_k >= seqlen_k or index_q >= seqlen_q: + acc_qk[i] = -Float32.inf diff --git a/examples/python/CuTeDSL/utils/sparse_utils.py b/examples/python/CuTeDSL/utils/sparse_utils.py new file mode 100644 index 00000000..24b3f791 --- /dev/null +++ b/examples/python/CuTeDSL/utils/sparse_utils.py @@ -0,0 +1,457 @@ +import numpy as np +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import torch + + +@cute.jit +def print_tensor_dlpack(src: cute.Tensor): + print(src) + cute.print_tensor(src) + + +# Sparse emulation +class SparseEmulation: + def __init__(self, M: int, N: int, K: int, L: int): + self.M = M + self.N = N + self.K = K + self.L = L + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor): + """Sparse emulation""" + num_threads = 128 + grid = (cute.ceil_div(self.M, num_threads), 1, 1) + block = (num_threads, 1, 1) + self.kernel(a, b, d, e).launch(grid=grid, block=block) + return + + @cute.kernel + def kernel(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor): + """CUDA kernel to emulate sparse tensor core""" + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + + row_idx = tidx + bidx * self.M + meta_idx = self.K // 4 // 8 + if row_idx < self.M: + # each thread process 1 row + for col in range(self.N): + # each meta_idx stands for 32 elements + for e_idx in range(meta_idx): + meta_val = e[(row_idx, e_idx)] + for k in range(8): + # each k stands for 4 elements + meta_row = (meta_val >> (k * 4)) & 0xF + idx0 = meta_row & 0x3 + idx1 = (meta_row >> 2) & 0x3 + # calculate the idx in b tensor which has value in A tensor + km = e_idx * 16 + k * 2 + km_1 = km + 1 + kn = e_idx * 32 + k * 4 + idx0 + kn_1 = e_idx * 32 + k * 4 + idx1 + d[row_idx, col] += a[row_idx, km] * b[col, kn] + d[row_idx, col] += a[row_idx, km_1] * b[col, kn_1] + return + + +# Compressor +# compress a sparse tensor to a dense tensor && generate metadata +class Compressor: + def __init__(self, M: int, K: int, L: int): + self.M = M + self.K = K + self.L = L + self.pos_map = { + 0x4: [0, 1], + 0x8: [0, 2], + 0xC: [0, 3], + 0x9: [1, 2], + 0xD: [1, 3], + 0xE: [2, 3], + } + + @cute.jit + def _init__(self, a: cute.Tensor): + self.__init__(a.shape[0], a.shape[1], a.shape[2]) + + def compress(self, a, a_compressed, meta, run_on_cpu: bool): + if run_on_cpu: + if a.device.type != "cpu": + raise ValueError("a must be on cpu") + return self.__compress_on_cpu(a, a_compressed, meta) + else: + if a.device.type != "cuda": + raise ValueError("a must be on cuda") + return self.__compress_on_cuda(a, a_compressed, meta) + + def __compress_on_cpu(self, a, a_compressed, meta): + """ + compress the tensor on cpu + # Convert to 4-bit metadata value + # The metadata value represents which 2 elements are non-zero + # 0x4: [1,1,0,0] - first two elements are non-zero + # 0x8: [1,0,1,0] - first and third elements are non-zero + # 0xC: [1,0,0,1] - first and fourth elements are non-zero + # 0x9: [0,1,1,0] - second and third elements are non-zero + # 0xD: [0,1,0,1] - second and fourth elements are non-zero + # 0xE: [0,0,1,1] - third and fourth elements are non-zero + # special case: + # [0,0,0,0] == [0,0,1,1] + # [1,0,0,0] == [1,0,0,1] + # [0,1,0,0] == [0,1,0,1] + # [0,0,1,0] == [0,0,1,1] + # [0,0,0,1] == [0,0,1,1] + """ + M, K = a.shape + assert a_compressed.shape == ( + M, + K // 2, + ), f"Expected a_compressed shape {(M, K // 2)}, got {a_compressed.shape}" + assert meta.shape == ( + M, + K // 4 // 8, + ), f"Expected meta shape {(M, K // 4 // 8)}, got {meta.shape}" + for m in range(M): + k_meta = 0 + for k in range(0, K, 4): + chunk = a[m, k : k + 4] + + non_zero_indices = torch.nonzero(chunk).squeeze() + meta_val = 0xE + if torch.equal(non_zero_indices, torch.tensor([0, 1])): + meta_val = 0x4 + elif torch.equal(non_zero_indices, torch.tensor([0, 2])): + meta_val = 0x8 + elif torch.equal(non_zero_indices, torch.tensor([0, 3])) or torch.equal( + non_zero_indices, torch.tensor(0) + ): + meta_val = 0xC + elif torch.equal(non_zero_indices, torch.tensor([1, 2])): + meta_val = 0x9 + elif torch.equal(non_zero_indices, torch.tensor([1, 3])) or torch.equal( + non_zero_indices, torch.tensor(1) + ): + meta_val = 0xD + elif torch.equal(non_zero_indices, torch.tensor([2, 3])) or torch.equal( + non_zero_indices, torch.tensor(2) + ): + meta_val = 0xE + elif torch.equal(non_zero_indices, torch.tensor([])) or torch.equal( + non_zero_indices, torch.tensor(3) + ): + meta_val = 0xE + else: + raise ValueError(f"Invalid non-zero pattern: {non_zero_indices}") + meta_idx = k // 4 // 8 + meta_bit_pos = (k // 4) % 8 + if k_meta == meta_idx: + k_meta = meta_idx + 1 + meta[m, meta_idx] = 0 + meta[m, meta_idx] |= meta_val << (meta_bit_pos * 4) + compressed_idx = k // 2 + index = self.pos_map[meta_val] + a_compressed[m, compressed_idx] = chunk[index[0]] + a_compressed[m, compressed_idx + 1] = chunk[index[1]] + + def __compress_on_cuda(self, a, a_compressed, meta): + """ + compress the tensor on cuda + """ + a_tensor = from_dlpack(a) + a_compressed_tensor = from_dlpack(a_compressed) + meta_tensor = from_dlpack(meta) + self.compress_on_cuda_impl(a_tensor, a_compressed_tensor, meta_tensor) + return + + @cute.jit + def compress_on_cuda_impl( + self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor + ): + """Compress the input tensor using the metadata""" + num_threads = 128 + grid = (cute.ceil_div(self.M, num_threads), 1, 1) + block = (num_threads, 1, 1) + self.compressor_impl(a, a_compressed, meta).launch(grid=grid, block=block) + + @cute.kernel + def compressor_impl( + self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor + ): + """CUDA kernel to compress the tensor""" + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + m = a.shape[0] + k = a.shape[1] + + # each thread process 1 row + row_idx = tidx + bidx * self.M + meta_idx = self.K // 4 // 8 + if row_idx < self.M: + # each meta_idx stands for 32 elements + for i in range(meta_idx): + meta[row_idx, i] = 0 + # each k stands for 4 elements + for j in range(8): + val = a[row_idx, i * 32 + j * 4] + val_1 = a[row_idx, i * 32 + j * 4 + 1] + val_2 = a[row_idx, i * 32 + j * 4 + 2] + val_3 = a[row_idx, i * 32 + j * 4 + 3] + value_idx = 0 + value_idx_1 = 0 + value_idx_2 = 0 + value_idx_3 = 0 + pos0 = 0 + pos1 = 0 + if val != 0: + value_idx = 1 + pos0 = 0 + if val_1 != 0: + value_idx_1 = 1 + if val_2 != 0: + value_idx_2 = 1 + if val_3 != 0: + value_idx_3 = 1 + pos = [value_idx, value_idx_1, value_idx_2, value_idx_3] + tmp = 0 + if pos == [0, 0, 0, 0]: + tmp = 0xE + pos0 = 2 + pos1 = 3 + elif pos == [1, 0, 0, 0]: + tmp = 0xC + pos0 = 0 + pos1 = 3 + elif pos == [0, 1, 0, 0]: + tmp = 0xD + pos0 = 1 + pos1 = 3 + elif pos == [0, 0, 1, 0]: + tmp = 0xE + pos0 = 2 + pos1 = 3 + elif pos == [0, 0, 0, 1]: + tmp = 0xE + pos0 = 2 + pos1 = 3 + elif pos == [1, 1, 0, 0]: + tmp = 0x4 + pos0 = 0 + pos1 = 1 + elif pos == [1, 0, 1, 0]: + tmp = 0x8 + pos0 = 0 + pos1 = 2 + elif pos == [1, 0, 0, 1]: + tmp = 0xC + pos0 = 0 + pos1 = 3 + elif pos == [0, 1, 1, 0]: + tmp = 0x9 + pos0 = 1 + pos1 = 2 + elif pos == [0, 1, 0, 1]: + tmp = 0xD + pos0 = 1 + pos1 = 3 + elif pos == [0, 0, 1, 1]: + tmp = 0xE + pos0 = 2 + pos1 = 3 + # cute.printf(row_idx, cutlass.Float32(val), cutlass.Float32(val_1), cutlass.Float32(val_2), cutlass.Float32(val_3), tmp) + meta[row_idx, i] |= tmp << (j * 4) + + a_compressed[row_idx, i * 16 + j * 2] = a[ + row_idx, i * 32 + j * 4 + pos0 + ] + a_compressed[row_idx, i * 16 + j * 2 + 1] = a[ + row_idx, i * 32 + j * 4 + pos1 + ] + + return + + +# SparseUtils is used to generate sparse tensor +# format torch.Tensor +class SparseUtils: + #!brief: SparseUtils is used to generate sparse tensor + #!param: M: int, K: int, L: int, dtype: cutlass.DataType + def __init__(self, M: int, K: int, L: int, dtype): + self.M = M + self.K = K + self.L = L + self.dtype = dtype + self.meta_data = self._generate_meta_data_4_2() + self._use_specific_meta_data = False + + #!brief: cast cutlass.DataType to torch.Tensor + def _get_type(self): + if self.dtype == cutlass.Float16: + return torch.float16 + elif self.dtype == cutlass.Float32: + return torch.float32 + elif self.dtype == cutlass.Int8: + return torch.int8 + else: + raise ValueError(f"Unsupported dtype: {self.dtype}") + + def _generate_meta_data_4_2(self): + # metadata for 4:2 sparse will in range( 4,8,9,c,d,e) + # represents + # 0: [1,1,0,0] no zero pos 00,01 -> 0100 = 4 + # 1: [1,0,1,0] no zero pos 00,10 -> 1000 = 8 + # 2: [1,0,0,1] no zero pos 00,11 -> 1100 = c + # 3: [0,1,1,0] no zero pos 01,10 -> 1001 = 9 + # 4: [0,1,0,1] no zero pos 01,11 -> 1101 = d + # 5: [0,0,1,1] no zero pos 10,11 -> 1011 = e + meta_value = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE] + # 4:2 sparse, so each chunk is 4 elements, map to 4 bits + K_NumChunk = self.K // 4 + meta_data = np.random.choice( + meta_value, size=(self.M, K_NumChunk), replace=True + ) + meta_data = torch.from_numpy( + np.array(meta_data).astype(np.uint8).reshape(self.M, K_NumChunk) + ) + return meta_data + + #!brief: pack meta data + def _pack_meta_data(self): + tmp = [] + K_NumChunk = self.K // 4 + for i in range(self.M): + for j in range(K_NumChunk // 8): + v = 0 + for k in range(8): + vv = int(self.meta_data[i, j * 8 + k] & 0xF) + tt = vv << (k * 4) + v = v | tt + tmp.append(v) + # debug print + # print([hex(vt) for vt in tmp]) + result = torch.from_numpy( + np.array(tmp).astype(np.uint32).reshape(self.M, K_NumChunk // 8) + ) + return result + + #!brief: use specific meta data + def use_specific_meta_data(self, meta_data: torch.Tensor = None): + if meta_data is not None: + self.meta_data = meta_data + self._use_specific_meta_data = True + + #!brief: generate sparse tensor with tensor + #!param: a: torch.Tensor + #!param: run_on_cpu: bool + #!return: torch.Tensor + def generate_sparse_4_2_tensor_with_tensor(self, a, run_on_cpu): + if run_on_cpu: + if a.device.type != "cpu": + raise ValueError("a must be on cpu") + return self.__generate_sparse_tensor_cpu(a) + else: + if a.device.type != "cuda": + raise ValueError("a must be on cuda") + a_tensor = from_dlpack(a) + packed_meta_data = self._pack_meta_data() + meta_tensor = from_dlpack(packed_meta_data.cuda()) + self.__generate_sparse_tensor_cuda(a_tensor, meta_tensor) + return a + + #!brief: generate sparse tensor + #!param: run_on_cpu: bool + #!return: torch.Tensor + def generate_4_2_sparse_tensor(self, run_on_cpu): + dtype = self._get_type() + a = torch.empty(self.M, self.K).random_(-5, 5).to(dtype) + if run_on_cpu: + return self.generate_sparse_4_2_tensor_with_tensor(a, run_on_cpu) + else: + return self.generate_sparse_4_2_tensor_with_tensor(a.cuda(), run_on_cpu) + + #!brief: generate sparse tensor on cpu + #!param: a: torch.Tensor + #!return: torch.Tensor + def __generate_sparse_tensor_cpu(self, a): + if not self._use_specific_meta_data: + for m in range(self.M): + for k in range(0, self.K, 4): + # random choose 2 zero positions + zero_indices = torch.randperm(4)[:2] + a[m, k + zero_indices[0]] = 0 + a[m, k + zero_indices[1]] = 0 + return a + else: + # use specific meta data + tensor_mask = [] + for i in range(self.M): + for j in range(self.K // 4): + meta_val = self.meta_data[i, j] + tmp = [] + if meta_val == 0x4: + tmp = [1, 1, 0, 0] + elif meta_val == 0x8: + tmp = [1, 0, 1, 0] + elif meta_val == 0xC: + tmp = [1, 0, 0, 1] + elif meta_val == 0x9: + tmp = [0, 1, 1, 0] + elif meta_val == 0xD: + tmp = [0, 1, 0, 1] + elif meta_val == 0xE: + tmp = [0, 0, 1, 1] + tensor_mask.extend(tmp) + a = torch.reshape(a, (-1,)) + mask = torch.tensor(tensor_mask) + a = a * mask + a = torch.reshape(a, (self.M, self.K)) + return a + + @cute.jit + def __generate_sparse_tensor_cuda(self, a: cute.Tensor, meta: cute.Tensor): + """Generate a sparse tensor from a dense tensor using metadata""" + assert a.shape[0] == self.M and a.shape[1] == self.K + assert meta.shape[0] == self.M and meta.shape[1] == self.K // 4 // 8 + num_threads = 128 + grid = (cute.ceil_div(self.M, num_threads), 1, 1) + block = (num_threads, 1, 1) + self.kernel(a, meta).launch(grid=grid, block=block) + + @cute.kernel + def kernel(self, a: cute.Tensor, meta: cute.Tensor): + """Apply sparsity mask to input tensor using metadata""" + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + + # each thread process 1 ro + row_idx = tidx + bidx * self.M + meta_idx = self.K // 4 // 8 + # each thread process 1 row + if row_idx < self.M: + # iterate over each chunk(32 elements) + for i in range(meta_idx): + meta_val = meta[(row_idx, i)] + # iterate over each sparse pattern(4 elements) + for j in range(8): + meta_row = (meta_val >> (j * 4)) & 0xF + idx0 = meta_row & 0x3 + idx1 = (meta_row >> 2) & 0x3 + r_id0 = 0 + r_id1 = 0 + # r_id is the idx that value is 0 + if idx0 >= 2 and idx1 >= 2: + r_id0 = 0 + r_id1 = 1 + elif idx0 <= 1 and idx1 <= 1: + r_id0 = 2 + r_id1 = 3 + else: + r_id0 = idx0 ^ 0b1 + r_id1 = idx1 ^ 0b1 + row_id0 = r_id0 + i * 32 + j * 4 + row_id1 = r_id1 + i * 32 + j * 4 + a[row_idx, row_id0] = self.dtype(0.0) + a[row_idx, row_id1] = self.dtype(0.0) + return diff --git a/examples/python/CuTeDSL/utils/test_sparse_utils.py b/examples/python/CuTeDSL/utils/test_sparse_utils.py new file mode 100644 index 00000000..3264f191 --- /dev/null +++ b/examples/python/CuTeDSL/utils/test_sparse_utils.py @@ -0,0 +1,104 @@ +import sparse_utils as su +import cutlass +import torch +from cutlass.cute.runtime import from_dlpack +import numpy as np +import pytest + + +@pytest.mark.L0 +def test_sparse_cpu(): + M = 128 + N = 32 + K = 32 + L = 1 + debug = False + # generate sparse tensor + a = torch.empty(M, K).random_(-5, 5).to(torch.float16) + sparse_utils = su.SparseUtils(M, K, L, cutlass.Float16) + if debug: + sparse_utils.use_specific_meta_data() + a_gen_from_cpu = sparse_utils.generate_sparse_4_2_tensor_with_tensor(a, True) + # print(a_gen_from_cpu) + # generate compressed tensor and meta data + a_compressed_cpu = torch.empty(M, K // 2).to(torch.float16) + meta_data_cpu = torch.empty(M, K // 4 // 8).to(torch.uint32) + compressor = su.Compressor(M, K, L) + compressor.compress(a_gen_from_cpu, a_compressed_cpu, meta_data_cpu, True) + # # test with gemm + b = torch.empty(N, K).random_(-5, 5).to(torch.float16).cuda() + d = torch.empty(M, N).zero_().to(torch.float16).cuda() + b_tensor = from_dlpack(b) + d_tensor = from_dlpack(d) + a_compressed_cpu_tensor = from_dlpack(a_compressed_cpu.cuda()) + meta_data_cpu_tensor = from_dlpack(meta_data_cpu.cuda()) + sparse_emulation = su.SparseEmulation(M, N, K, 1) + sparse_emulation(a_compressed_cpu_tensor, b_tensor, d_tensor, meta_data_cpu_tensor) + + ref = torch.einsum("mk,nk->mn", a_gen_from_cpu.cpu(), b.cpu()) + if debug: + a_ori = a_gen_from_cpu.cpu().numpy() + np.savetxt("a.txt", a_ori, fmt="%f") + a_compressed_cpu_ori = a_compressed_cpu.cpu().numpy() + np.savetxt("a_compressed_cpu.txt", a_compressed_cpu_ori, fmt="%f") + meta_data_cpu_ori = meta_data_cpu.cpu().numpy() + np.savetxt("meta_data_cpu.txt", meta_data_cpu_ori, fmt="%f") + d_ori = d.cpu().numpy() + np.savetxt("d.txt", d_ori, fmt="%f") + ref_ori = ref.cpu().numpy() + np.savetxt("ref.txt", ref_ori, fmt="%f") + torch.testing.assert_close(d.cpu(), ref) + print("cpu d == ref") + + +@pytest.mark.L0 +def test_sparse_cuda(): + M = 128 + N = 32 + K = 32 + L = 1 + debug = False + sparse_utils = su.SparseUtils(M, K, L, cutlass.Float16) + if debug: + sparse_utils.use_specific_meta_data() + # generate sparse tensor + a = torch.empty(M, K).random_(-5, 5).to(torch.float16).cuda() + a_gen_from_cuda = sparse_utils.generate_4_2_sparse_tensor(False) + # print(a_gen_from_cuda) + # generate compressed tensor and meta data + a_compressed_cuda = torch.empty(M, K // 2).to(torch.float16).cuda() + meta_data_cuda = torch.empty(M, K // 4 // 8).to(torch.uint32).cuda() + compressor = su.Compressor(M, K, L) + compressor.compress(a_gen_from_cuda, a_compressed_cuda, meta_data_cuda, False) + # test with gemm + b = torch.empty(N, K).random_(-5, 5).to(torch.float16).cuda() + d = torch.empty(M, N).zero_().to(torch.float16).cuda() + b_tensor = from_dlpack(b) + d_tensor = from_dlpack(d) + a_compressed_cuda_tensor = from_dlpack(a_compressed_cuda) + meta_data_cuda_tensor = from_dlpack(meta_data_cuda) + sparse_emulation = su.SparseEmulation(M, N, K, 1) + sparse_emulation( + a_compressed_cuda_tensor, b_tensor, d_tensor, meta_data_cuda_tensor + ) + + ref = torch.einsum("mk,nk->mn", a_gen_from_cuda.cpu(), b.cpu()) + if debug: + a_ori = a_gen_from_cuda.cpu().numpy() + np.savetxt("a.txt", a_ori, fmt="%f") + a_compressed_cuda_ori = a_compressed_cuda.cpu().numpy() + np.savetxt("a_compressed_cuda.txt", a_compressed_cuda_ori, fmt="%f") + meta_data_cuda_ori = meta_data_cuda.cpu().numpy() + np.savetxt("meta_data_cuda.txt", meta_data_cuda_ori, fmt="%f") + d_ori = d.cpu().numpy() + np.savetxt("d.txt", d_ori, fmt="%f") + ref_ori = ref.cpu().numpy() + np.savetxt("ref.txt", ref_ori, fmt="%f") + torch.testing.assert_close(d.cpu(), ref) + print("cuda d == ref") + + +if __name__ == "__main__": + cutlass.cuda.initialize_cuda_context() + test_sparse_cpu() + test_sparse_cuda() diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index 7ce05f30..2f9fd363 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -156,7 +156,8 @@ # define CUTE_ARCH_TMA_SM120_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) # define CUTE_ARCH_F8F6F4_MMA_ENABLED # define CUTE_ARCH_MXF8F6F4_MMA_ENABLED @@ -165,15 +166,6 @@ # endif #endif -#if (defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) -# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) -# define CUTE_ARCH_F8F6F4_MMA_ENABLED -# define CUTE_ARCH_MXF8F6F4_MMA_ENABLED -# define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED -# define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED -# endif -#endif - #if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index 1d33361f..1ab989a6 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -516,7 +516,7 @@ constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits) // var has type cute::constant. // template -constexpr cute::constant operator "" _c() +constexpr cute::constant operator""_c() { static_assert((('0' <= digits && digits <= '9') && ...), "Expected 0 <= digit <= 9 for each digit of the integer."); diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index 0c1f4d36..1609d73b 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -363,6 +363,7 @@ struct numeric_limits { /// Returns smallest finite value CUTLASS_HOST_DEVICE static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x3c00); } + /// Returns smallest finite value CUTLASS_HOST_DEVICE static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } @@ -431,6 +432,7 @@ struct numeric_limits { /// Returns smallest finite value CUTLASS_HOST_DEVICE static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x3c00); } + /// Returns smallest finite value CUTLASS_HOST_DEVICE static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } @@ -665,12 +667,12 @@ bfloat16_t operator--(bfloat16_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(long double x) { +cutlass::bfloat16_t operator""_bf16(long double x) { return cutlass::bfloat16_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { +cutlass::bfloat16_t operator""_bf16(unsigned long long int x) { return cutlass::bfloat16_t(int(x)); } diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 16cfa1b3..28eeebcd 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -102,6 +102,7 @@ public: template CUTLASS_HOST_DEVICE Coord(Coord other) { + static_assert(kRank == R); for (int i = 0; i < kRank; ++i) { idx[i] = other[i]; } diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index 41fc4a9e..1ffe44e9 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -625,14 +625,14 @@ public: return 0; } else if constexpr (ModeHasScales) { - constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return scale_tx_bytes; } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { // Scale and zero share smem layout - constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA return scale_tx_bytes + zero_tx_bytes; } @@ -1189,8 +1189,7 @@ public: auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(threadIdx.x % 128); Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsS = cta_mma.partition_A(sS); - Tensor tSsS = smem_thr_copy_S.partition_S(tCsS); + Tensor tSsS = smem_thr_copy_S.partition_S(sS); Tensor tSrS = make_tensor(tSsS(_,_,_,_,0).shape()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { @@ -1198,8 +1197,7 @@ public: } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsZ = cta_mma.partition_A(sZ); - Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ); + Tensor tZsZ = smem_thr_copy_S.partition_S(sZ); Tensor tZrZ = make_tensor(tZsZ(_,_,_,_,0).shape()); return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ); } diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index fb09f8b1..f93eff7b 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -40,6 +40,7 @@ #include "cute/tensor.hpp" #include "cute/numeric/numeric_types.hpp" #include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90_desc.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -315,13 +316,37 @@ public: return false; } + template CUTLASS_DEVICE auto load_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx) { - return cute::make_tuple(nullptr); + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + return cute::make_tuple( + tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, 0) + ); + } + + template + CUTLASS_DEVICE auto + tensormaps_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx = 0) { + // In the async tensormap update kernels, we will use operator[] to index the return value to locate the correct tensormap. + // In other kernels, we will use return value as tensormap pointer directly. + struct { + CUTLASS_DEVICE operator cute::TmaDescriptor *() const { + return reinterpret_cast(0); + } + CUTLASS_DEVICE auto operator [] (int) const { + return reinterpret_cast(0); + } + } ret; + return ret; } template< @@ -377,14 +402,17 @@ public: return load_pipe_producer_state; } + template CUTLASS_DEVICE auto store_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx, - [[maybe_unused]] int32_t warp_group_idx) { - return cute::make_tuple(nullptr); + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx = 0) { + return cute::make_tuple( + tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, 0) + ); } template< @@ -495,6 +523,7 @@ public: // Dummy methods to perform different parts of TMA/Tensormap modifications template CUTLASS_DEVICE void @@ -504,15 +533,17 @@ public: [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape, [[maybe_unused]] int32_t next_batch, - [[maybe_unused]] int32_t warp_group_idx) { } + [[maybe_unused]] int32_t warp_group_idx = 0 + ) { } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release( [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] int32_t warp_group_idx) { } + [[maybe_unused]] int32_t warp_group_idx = 0 + ) { } template CUTLASS_DEVICE @@ -537,6 +568,10 @@ public: static constexpr int NumAccumulatorMtxs = Sm100EpilogueOpNumAccumulatorMtxs::value; + // Epilog assumes a max scheduler pipe count to calculate the number of asynchronous tma update buffer they need. + // In these epilogues, we don't need to update tensormaps at all. Setting this to INT_MAX. + constexpr static uint32_t NumMaxSchedulerPipelineStageCount = INT_MAX; + template CUTLASS_HOST_DEVICE static constexpr int @@ -564,13 +599,37 @@ public: // ctor inheritance using EpilogueOp::EpilogueOp; + template + CUTLASS_DEVICE auto + tensormaps_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx = 0) const { + // In the async tensormap update kernels, we will use operator[] to index the return value to locate the correct tensormap. + // In other kernels, we will use return value as tensormap pointer directly. + struct { + CUTLASS_DEVICE operator cute::TmaDescriptor *() const { + return reinterpret_cast(0); + } + CUTLASS_DEVICE auto operator [] (int) const { + return reinterpret_cast(0); + } + } ret; + return ret; + } + + template CUTLASS_DEVICE auto load_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { - return cute::make_tuple(nullptr); + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + return cute::make_tuple( + tensormaps_init(params, shared_tensormap, sm_count, sm_idx, 0) + ); } template< @@ -633,13 +692,16 @@ public: { } + template CUTLASS_DEVICE auto store_init( - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { - return cute::make_tuple(nullptr); + typename EpilogueOp::Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + return cute::make_tuple( + tensormaps_init(params, shared_tensormap, sm_count, sm_idx, 0) + ); } template< @@ -703,18 +765,18 @@ public: > CUTLASS_DEVICE auto store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, - TensorStorage& shared_tensors, - TiledCopyT2R tiled_t2r) + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TiledCopyT2R tiled_t2r) { (*this)( problem_shape_mnkl, @@ -740,19 +802,19 @@ public: > CUTLASS_DEVICE auto store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - CtaTileMNK cta_tile_mnk, - CtaCoordMNKL cta_coord_mnkl, - MmaTileMNK mma_tile_mnk, - TiledMma tiled_mma, - cute::Tensor& tTR_rAcc, - TensorStorage& shared_tensors, - TensorMap tensormap, - TiledCopyT2R tiled_t2r) { + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TensorMap tensormap, + TiledCopyT2R tiled_t2r) { (*this)( problem_shape_mnkl, cta_tile_mnk, @@ -765,6 +827,7 @@ public: template< bool ReuseTmem = false, + bool WaitForInflightTmaRequests = true, class AccumulatorPipeline, class AccumulatorPipelineState, class ProblemShapeMNKL, @@ -825,8 +888,7 @@ public: } // Dummy methods to perform different parts of TMA/Tensormap modifications - - template + template CUTLASS_DEVICE void tensormaps_perform_update( @@ -834,14 +896,16 @@ public: [[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShape problem_shape, - [[maybe_unused]] int32_t next_batch) { } + [[maybe_unused]] int32_t next_batch + ) { } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release( [[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] cute::TmaDescriptor const* tensormap) { } + [[maybe_unused]] cute::TmaDescriptor const* tensormap + ) { } template CUTLASS_DEVICE diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp index d3b2d088..afdc4528 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp @@ -398,10 +398,33 @@ public: if (epilogue_op.is_source_needed()) { ptr_C_l = params.ptr_C[l_coord]; } + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) @@ -572,12 +595,7 @@ public: can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { - - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - return fusion_implementable; + return true; } diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp index 1f0a915d..e9f06f24 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -128,6 +128,9 @@ public: static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + // Epilog assumes a max scheduler pipe count to calculate the number of asynchronous tma update buffer they need. + constexpr static uint32_t NumMaxSchedulerPipelineStageCount = 8; + private: constexpr static bool is_source_supported = not cute::is_void_v; @@ -177,6 +180,11 @@ private: // TMA store delay only benefits with loop unrolling constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = NumMaxSchedulerPipelineStageCount + std::max(StagesC, (ReuseSmemC ? StagesC : StagesD)) + 2; + struct CollectiveStorageWithC { alignas(SmemAlignmentC) ArrayEngine> smem_C; alignas(SmemAlignmentD) ArrayEngine> smem_D; @@ -346,7 +354,8 @@ public: constexpr uint32_t NumInputTensors = cute::is_void_v ? 1 : 2; constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count) + (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); + return (NumInputTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm) + + (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); } template @@ -456,16 +465,22 @@ public: return fusion_callbacks.is_producer_load_needed(); } + template CUTLASS_DEVICE auto load_init( Params const& params, TensorMapStorage& shared_tensormap, int32_t const sm_count, int32_t const sm_idx) const { - // Fetch a copy of tensormaps for the CTA from Params - constexpr bool IsEpiLoad = true; - auto load_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); - return cute::make_tuple(load_tensormap); + if constexpr (IsTmaAsyncUpdate) { + // Async update kernels will fetch the tensormap directly from tensormaps_init. + return cute::make_tuple(); + } else { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = true; + auto load_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + return cute::make_tuple(load_tensormap); + } } template< @@ -581,22 +596,27 @@ public: load_pipeline.producer_tail(load_pipe_producer_state); } + template CUTLASS_DEVICE auto store_init( Params const& params, TensorMapStorage& shared_tensormap, int32_t const sm_count, int32_t const sm_idx) const { - // Fetch a copy of tensormaps for the CTA from Params - constexpr bool IsEpiLoad = false; - cute::TmaDescriptor* store_tensormap = nullptr; - int thread_idx = threadIdx.x % ThreadCount; - int warp_idx = thread_idx / NumThreadsPerWarp; - // Only the first epilogue warp needs to perform TMA related operations - if (warp_idx == 0) { - store_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + if constexpr (IsTmaAsyncUpdate) { + return cute::make_tuple(); + } else { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = false; + cute::TmaDescriptor* store_tensormap = nullptr; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + // Only the first epilogue warp needs to perform TMA related operations + if (warp_idx == 0) { + store_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + } + return cute::make_tuple(store_tensormap); } - return cute::make_tuple(store_tensormap); } template< @@ -1343,17 +1363,44 @@ public: // Methods to perform different parts of TMA/Tensormap modifications // - template + template CUTLASS_DEVICE auto tensormaps_init(Params const& params, TensorMapStorage& shared_tensormap, int32_t const sm_count, - int32_t const sm_idx) const { + int32_t const sm_idx, + bool const is_leader_warp = true) const { + + // Define a local struct that provides simple array indexing for TMA descriptors + struct TensorMapArray { + cute::TmaDescriptor* tma_desc; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* desc) : tma_desc(desc) {} + + CUTLASS_DEVICE + cute::TmaDescriptor* + operator[](int32_t idx) const { + return tma_desc + (idx % NumTmaDescriptorsPerSm); + } + }; + cute::TmaDescriptor* tma_desc = nullptr; cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + + if (!is_leader_warp) { + if constexpr (IsTmaAsyncUpdate) { + return TensorMapArray{tma_desc}; + } else { + return tma_desc; + } + } + if constexpr (IsLoad) { - if (is_source_supported) { - tma_desc = &gmem_tensormap[sm_idx]; + if constexpr (is_source_supported) { + tma_desc = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); @@ -1364,7 +1411,7 @@ public: } } else if constexpr (is_destination_supported) { int const offset_Ddesc = cute::is_void_v ? 0 : sm_count; - tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; + tma_desc = &gmem_tensormap[(sm_idx + offset_Ddesc) * NumTmaDescriptorsPerSm]; if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); @@ -1374,7 +1421,11 @@ public: __syncwarp(); } - return tma_desc; + if constexpr (IsTmaAsyncUpdate) { + return TensorMapArray{tma_desc}; + } else { + return tma_desc; + } } // Replace address for the global tensor (to be done by single thread) @@ -1451,15 +1502,16 @@ public: } // The entire warp must call this function collectively (that is, the instructions are aligned) - template + template CUTLASS_DEVICE void tensormaps_perform_update( TensorMapStorage& shared_tensormap, Params const& params, - cute::TmaDescriptor const* tensormap, + cute::TmaDescriptor* tensormap, ProblemShape problem_shape, - int32_t next_batch) { + int32_t next_batch + ) { if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormap, params, next_batch); @@ -1474,16 +1526,21 @@ public: // Ensure warp is converged before issuing tensormap fence release __syncwarp(); // Entire warp must do this (ie its aligned) - tensormaps_cp_fence_release(shared_tensormap, tensormap); + tensormaps_cp_fence_release( + shared_tensormap, + tensormap + ); } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release( TensorMapStorage& shared_tensormap, - cute::TmaDescriptor const* tensormap) { - // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + cute::TmaDescriptor* tensormap + ) { + + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem if we're not using async update. // This operation only happens when the group/batch changes between consecutive tiles. // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { @@ -1492,14 +1549,19 @@ public: cute::tma_desc_wait_group(); } }; + // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { - if (is_source_supported) { - tma_desc_wait_all_fn(); + if constexpr (is_source_supported) { + if constexpr (WaitForInflightTmaRequests) { + tma_desc_wait_all_fn(); + } tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); } } else if constexpr (is_destination_supported) { - tma_desc_wait_all_fn(); + if constexpr (WaitForInflightTmaRequests) { + tma_desc_wait_all_fn(); + } tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); } } @@ -1509,7 +1571,7 @@ public: void tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { if constexpr (IsLoad) { - if (is_source_supported) { + if constexpr (is_source_supported) { cute::tma_descriptor_fence_acquire(tensormap); } } else if constexpr (is_destination_supported) { diff --git a/include/cutlass/exmy_base.h b/include/cutlass/exmy_base.h index be207a49..3588088d 100644 --- a/include/cutlass/exmy_base.h +++ b/include/cutlass/exmy_base.h @@ -79,7 +79,7 @@ enum class FpEncoding E8M23, // float E5M2, // FP8 E4M3, // FP8 - UE4M3, // FP8 + UE4M3, // FP8 UE8M0, // FP8 E3M2, // FP6 E2M3, // FP6 @@ -869,7 +869,7 @@ CUTLASS_CONSTEXPR_IF_CXX17 auto fp_encoding_selector() { else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE4M3) { // FP8 return cutlass::detail::FpBitRepresentation{}; } - + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE8M0) { // FP8 return cutlass::detail::FpBitRepresentation{}; } diff --git a/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl index 3edd9280..9f94070f 100644 --- a/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl @@ -106,6 +106,7 @@ sm100_compute_stage_count_or_override_fast_fp32(StageCountAutoCarveout struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, float, // ElementA GmemLayoutATag, // LayoutA @@ -131,6 +132,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && (not cute::is_tuple::value && not cute::is_tuple::value) && (cute::is_base_of_v) && ((sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0) && @@ -226,9 +229,11 @@ struct CollectiveBuilder< TensorMapStorage); // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations - static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_fast_fp32< - Sm100ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, BuilderScheduleTag, UmmaMajorA>(StageCountType{}); + ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, BuilderScheduleTag, UmmaMajorA>(StageCountType{}); static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info); static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl index b8824c23..487c259a 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl @@ -44,6 +44,7 @@ namespace detail { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -62,6 +63,7 @@ sm100_compute_stage_count_or_override_blockscaled_sparse(StageCount stag // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -110,7 +112,7 @@ sm100_compute_stage_count_or_override_blockscaled_sparse(StageCountAutoCarveout< constexpr auto EpilogueSharedStorage = carveout_bytes; - constexpr auto Stages = (cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout - EpilogueSharedStorage) / + constexpr auto Stages = (CapacityBytes - KernelSmemCarveout - EpilogueSharedStorage) / (MainloopTensorStorage_per_Stage + MainloopPipelineStorage_per_Stage_aligned); return Stages; @@ -121,6 +123,7 @@ sm100_compute_stage_count_or_override_blockscaled_sparse(StageCountAutoCarveout< ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementPairA, class GmemLayoutATag, int AlignmentA, @@ -134,7 +137,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassBlockScaledSparseTensorOp, ElementPairA, GmemLayoutATag, @@ -148,6 +151,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && // Blockscaled Sparse Gemm cute::is_base_of_v && @@ -272,7 +277,12 @@ struct CollectiveBuilder< using SmemTileShape = cute::Shape; + // Calculate SMEM capacity based on ArchTag + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes; + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_sparse< + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, ElementEMma, diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index 68600c67..f5e0ed70 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -92,6 +92,7 @@ sm100_compute_stage_count_or_override_blockscaled(StageCountAutoCarveout struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassBlockScaledTensorOp, ElementPairA, GmemLayoutATag, @@ -119,6 +120,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && // Blockscaled Gemm (not cute::is_same_v) && (cute::is_base_of_v || @@ -250,12 +253,14 @@ struct CollectiveBuilder< 4 // 4 Tensor maps for A, SFA, B and SFB >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; using SmemTileShape = cute::Shape; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, SFA, and SFB."); using DispatchPolicy = diff --git a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl index 0566905d..ec339635 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl @@ -237,6 +237,7 @@ sm100_make_trivial_tiled_mma_blockwise() { ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementA, class GmemLayoutATagPair, int AlignmentA, @@ -250,7 +251,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ElementA, GmemLayoutATagPair, @@ -264,6 +265,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && not cute::is_tuple_v && not cute::is_tuple_v && not cute::is_complex_v && not cute::is_complex_v && cute::is_tuple_v && cute::is_tuple_v && @@ -369,7 +372,9 @@ struct CollectiveBuilder< IsArrayOfPointersGemm >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; using SmemTileShape = cute::Shape; using MainloopABPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; @@ -399,7 +404,7 @@ struct CollectiveBuilder< using ScaleTileShape = cute::Shape; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockwise< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, ElementAccumulator, ScaleTileShape, SmemTileShape, MainloopABPipelineStorage, MainloopSFPipelineStorage>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, and scales."); diff --git a/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl index 6a30b41b..1f61784d 100644 --- a/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl @@ -39,6 +39,7 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementA, class GmemLayoutATag, int AlignmentA, @@ -52,7 +53,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ElementA, GmemLayoutATag, @@ -65,10 +66,13 @@ struct CollectiveBuilder< ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) StageCountType, BuilderScheduleTag, - cute::enable_if_t || - (cute::is_same_v && - (((sizeof(ElementA) * AlignmentA) % cutlass::gemm::collective::detail::tma_alignment_bytes != 0) || - ((sizeof(ElementB) * AlignmentB) % cutlass::gemm::collective::detail::tma_alignment_bytes != 0)))> + cute::enable_if_t< + (cute::is_same_v + ) && + (cute::is_same_v || + (cute::is_same_v && + (((sizeof(ElementA) * AlignmentA) % cutlass::gemm::collective::detail::tma_alignment_bytes != 0) || + ((sizeof(ElementB) * AlignmentB) % cutlass::gemm::collective::detail::tma_alignment_bytes != 0))))> > { static_assert(cute::is_static_v, "TileShape has to be static"); @@ -137,13 +141,15 @@ struct CollectiveBuilder< CLCPipelineStorage + CLCResponseStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; using SmemTileShape = cute::Shape; using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); using CollectiveOp = cutlass::gemm::collective::CollectiveMma< cutlass::gemm::MainloopSm100UmmaCpAsyncWarpSpecialized< diff --git a/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl index eed95105..298fa6e9 100644 --- a/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl @@ -152,6 +152,7 @@ constexpr int get_ScaleGranularityK() { // Mixed Input MMA kernels builder template < + class ArchTag, class ElementAOptionalTuple, class GmemLayoutATagTuple, int AlignmentA, @@ -165,7 +166,7 @@ template < class KernelScheduleType > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ElementAOptionalTuple, // ElementA GmemLayoutATagTuple, // LayoutA @@ -179,6 +180,8 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< + (cute::is_same_v + ) && (cute::is_base_of_v) && ((sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0) && ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>> @@ -304,12 +307,12 @@ struct CollectiveBuilder< TensorMapStorage); // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations - static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; static constexpr int ScaleGranularityK = get_ScaleGranularityK(); - static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_mixed_input< - Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB, CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{}); + ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB, CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{}); static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info); static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); diff --git a/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl b/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl index fc4aa4a2..664e04e7 100644 --- a/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl +++ b/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl @@ -46,11 +46,13 @@ struct Sm100DenseGemmTmaUmmaCarveout { // AccumulatorPipeline = PipelineUmmaAsync static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); // CLCPipeline = PipelineCLCFetchAsync - static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // For pointer-array and grouped GEMM, we have two CLC responses, one for TMA updater, one for the TMA/MMA/Epilogue warps. + static constexpr int NumCLCResponses = (IsArrayOfPointersGemm ? 2 : 1); + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage) * NumCLCResponses; // LoadOrderBarrier = OrderedSequenceBarrier<1,2> static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); // CLC (scheduler) response - static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize * NumCLCResponses; // CLC Throttle pipeline storage static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); // Tmem dealloc @@ -59,8 +61,14 @@ struct Sm100DenseGemmTmaUmmaCarveout { static constexpr auto TmemBasePtrsStorage = SchedulerPipelineStageCount * sizeof(uint32_t); // Tensormap Storage static constexpr auto TensorMapStorage = - IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * NumTensorMaps /* for A and B */ : + IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * NumTensorMaps * 5 /* We have five tensormaps smem */ : 0; + + // TensorMapReady pipeline storage (specific to grouped/array kernels) + static constexpr auto TensorMapReadyPipelineStorage = + IsArrayOfPointersGemm ? sizeof(typename cutlass::PipelineAsync::SharedStorage) : + 0; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + CLCPipelineStorage + @@ -69,7 +77,8 @@ struct Sm100DenseGemmTmaUmmaCarveout { CLCThrottlePipelineStorage + CLCResponseStorage + TmemBasePtrsStorage + - TensorMapStorage + TensorMapStorage + + TensorMapReadyPipelineStorage ); }; diff --git a/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl b/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl index 15ad6bc2..31bce408 100644 --- a/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_simt_builder.inl @@ -96,6 +96,7 @@ sm100_make_simt_f32_tiled_mma() { } // namespace detail template < + class ArchTag, class GmemLayoutATag, int AlignmentA, class GmemLayoutBTag, @@ -105,7 +106,7 @@ template < int stages, class BuilderScheduleTag> struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassSimt, float, GmemLayoutATag, @@ -119,6 +120,8 @@ struct CollectiveBuilder< StageCount, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && (cute::is_same_v || cute::is_same_v || cute::is_same_v) && diff --git a/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl index c7d380a9..ba54d2ee 100644 --- a/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl @@ -44,6 +44,7 @@ namespace detail { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -60,6 +61,7 @@ sm100_compute_stage_count_or_override_sparse(StageCount stage_count) { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template < + int CapacityBytes, class ElementAMma, class ElementB, class ElementEMma, @@ -104,7 +106,7 @@ sm100_compute_stage_count_or_override_sparse(StageCountAutoCarveout struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassSparseTensorOp, ElementA, GmemLayoutATag, @@ -296,6 +299,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && (not cute::is_tuple_v && not cute::is_tuple_v && not cute::is_complex_v && not cute::is_complex_v && not cute::is_sparse_v) && @@ -375,7 +380,12 @@ struct CollectiveBuilder< using SmemTileShape = cute::Shape; + // Calculate SMEM capacity based on ArchTag + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes; + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_sparse< + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, ElementEMma, diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index dfd4fece..0822673e 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -153,6 +153,7 @@ check_input_datatypes() { ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementA, class GmemLayoutATag, int AlignmentA, @@ -166,7 +167,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ElementA, GmemLayoutATag, @@ -180,6 +181,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && not cute::is_tuple_v && not cute::is_tuple_v && not cute::is_complex_v && not cute::is_complex_v && // Dense Gemm / PtrArrayDenseGemm @@ -265,11 +268,17 @@ struct CollectiveBuilder< // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler. - static constexpr bool IsGroupGemm = !cute::is_same_v; - static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); + // Perform checks for both StrideA and StrideB to filter out Ragged Continguous Group Gemm + static constexpr bool IsGroupGemm = !(cute::is_same_v) && !(cute::is_same_v); + static constexpr bool IsRCGroupGemm = (cute::is_same_v) && !(cute::is_same_v); + + static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< ClusterShape_MNK, @@ -279,23 +288,34 @@ struct CollectiveBuilder< IsArrayOfPointersGemm >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + using SmemTileShape = cute::Shape; using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, and B."); using DispatchPolicy = cute::conditional_t, + cute::conditional_t, + cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + > + >, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecialized< PipelineStages, SchedulerPipelineStageCount, diff --git a/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl index 83f106c2..d69af786 100644 --- a/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl @@ -313,6 +313,7 @@ auto sSFB = [&]() { ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementPairA, class GmemLayoutATag, int AlignmentA, @@ -326,7 +327,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm103, + ArchTag, arch::OpClassBlockScaledTensorOp, ElementPairA, GmemLayoutATag, @@ -340,6 +341,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + (cute::is_same_v + ) && // Not paired input, Not Complex input (cute::is_tuple_v && cute::is_tuple_v && not cute::is_complex_v && not cute::is_complex_v) && @@ -495,11 +498,12 @@ struct CollectiveBuilder< TensorMapStorage + TmaPrefetchStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = + cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; using SmemTileShape = cute::Shape, Int, _128>; // SmemAllocTypes are uint8_t. We always allocate 128bytes static constexpr auto PipelineStages = cutlass::gemm::collective::detail::sm103_compute_stage_count_or_override_blockscaled< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); using DispatchPolicy = typename cute::conditional_t(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + // Reserve 128B for 8 stages of tile scheduling static constexpr size_t SchedulerPipelineStorage = cute::is_pointer_v> ? sizeof(cutlass::PipelineDetail::PipelineAsyncSharedStorage<8>) : 0; diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 9e3ae800..15a89fa9 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -55,6 +55,7 @@ #if !defined(__CUDACC_RTC__) #include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp" #include "cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp" diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp index 2665ef1c..cf9a1349 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -143,6 +143,11 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + using ElementPairA = ElementPairA_; using ElementPairB = ElementPairB_; using ElementAMma = typename TiledMma::ValTypeA; @@ -571,13 +576,18 @@ struct CollectiveMma< }; } + struct TensorMaps : cute::aligned_struct<256, _0> { + cute::TmaDescriptor tma_desc_a; + cute::TmaDescriptor tma_desc_b; + cute::TmaDescriptor tma_desc_sfa; + cute::TmaDescriptor tma_desc_sfb; + }; + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - constexpr uint32_t NumInputTensors = 4; - constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); - // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count); + // Allocate gmem space for input tensormaps per each SM. + return (sm_count * sizeof(TensorMaps) * NumTmaDescriptorsPerSm); } template @@ -674,7 +684,7 @@ struct CollectiveMma< /// mcast_mask_b - tma multicast mask for B /// mcast_mask_sfa - tma multicast mask for SFA /// mcast_mask_sfb - tma multicast mask for SFB - template + template CUTLASS_DEVICE auto load_init( ProblemShape_MNKL const& problem_shape_MNKL, @@ -682,6 +692,7 @@ struct CollectiveMma< TensorStorage& shared_tensors, TensorMapStorage& shared_tensormaps, int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t num_groups, int32_t init_group) const { using X = Underscore; @@ -788,15 +799,19 @@ struct CollectiveMma< uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); - // Fetch a copy of tensormaps for the CTA from Params - auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + auto ret = cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb); // multicast masks - return cute::make_tuple( - gA_mkl, gB_nkl, // for scheduler - tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values - tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values - mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, // multicast masks - input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } } /// Set up the data needed by this collective for mma compute. @@ -895,7 +910,8 @@ struct CollectiveMma< cute::tuple> const& load_inputs, TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count, - bool did_batch_change) { + bool did_batch_change, + [[maybe_unused]] int curr_batch) { auto [unused_gA, unused_gB, tAgA_mkl, tBgB_nkl, tAsA, tBsB, @@ -1116,19 +1132,15 @@ struct CollectiveMma< // Methods to perform different parts of TMA/Tensormap modifications // + template CUTLASS_DEVICE auto tensormaps_init( Params const& mainloop_params, TensorMapStorage& shared_tensormaps, int32_t const sm_count, int32_t const sm_idx) const { - cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; - cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; - cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; - - cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; - cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + TensorMaps* gmem_tensormap = &(reinterpret_cast(mainloop_params.tensormaps)[sm_idx * NumTmaDescriptorsPerSm]); if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later @@ -1148,9 +1160,30 @@ struct CollectiveMma< copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); } + __syncwarp(); - return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_sfa, tma_desc_sfb); + struct TensorMapArray { + + TensorMaps *tensor_maps; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(void* tensormaps) : tensor_maps(reinterpret_cast(tensormaps)) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(&tensor_maps[idx].tma_desc_a, &tensor_maps[idx].tma_desc_b, &tensor_maps[idx].tma_desc_sfa, &tensor_maps[idx].tma_desc_sfb); + } + }; + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(gmem_tensormap); + } else { + return cute::make_tuple(&gmem_tensormap->tma_desc_a, &gmem_tensormap->tma_desc_b, &gmem_tensormap->tma_desc_sfa, &gmem_tensormap->tma_desc_sfb); + } } // Replace address for the global tensor (to be done by single thread) @@ -1244,7 +1277,7 @@ struct CollectiveMma< } // The entire warp must call this function collectively (that is, the instructions are aligned) - template + template CUTLASS_DEVICE void tensormaps_perform_update( @@ -1252,10 +1285,9 @@ struct CollectiveMma< Params const& mainloop_params, cute::tuple const& input_tensormaps, ProblemShape problem_shape, - int32_t next_batch) { + int32_t next_batch + ) { if (cute::elect_one_sync()) { - // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); if constexpr (IsGroupedGemmKernel) { auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); @@ -1263,23 +1295,34 @@ struct CollectiveMma< tensormaps_replace_global_tensor_properties(shared_tensormaps, mainloop_params, next_batch, problem_shape_MNKL); } + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); } // Ensure warp is converged before issuing tensormap fence release __syncwarp(); // Entire warp must do this (ie its aligned) - tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release ( TensorMapStorage& shared_tensormaps, - cute::tuple const& input_tensormaps) { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); + cute::tuple const& input_tensormaps + ) { + + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } } + // Entire warp must do this (i.e. it's aligned) tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp index d832a1fc..e78ac3c4 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp @@ -123,6 +123,11 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + using ElementA = ElementA_; using ElementAMma = typename TiledMma::ValTypeA; using StrideA = StrideA_; @@ -417,7 +422,7 @@ struct CollectiveMma< constexpr uint32_t NumInputTensors = 2; constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count); + return (NumInputTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm); } template @@ -497,7 +502,7 @@ struct CollectiveMma< /// tBsB - partitioned smem tensor for B /// mcast_mask_a - tma multicast mask for A /// mcast_mask_b - tma multicast mask for B - template + template CUTLASS_DEVICE auto load_init( ProblemShape_MNKL const& problem_shape_MNKL, @@ -505,6 +510,7 @@ struct CollectiveMma< TensorStorage& shared_tensors, TensorMapStorage& shared_tensormaps, int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t num_groups, [[maybe_unused]] int32_t init_group) const { using X = Underscore; @@ -550,14 +556,20 @@ struct CollectiveMma< uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - // Fetch a copy of tensormaps for the CTA from Params - auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); - - return cute::make_tuple( + auto ret = cute::make_tuple( gA_mkl, gB_nkl, // for scheduler tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values - mcast_mask_a, mcast_mask_b, // multicast masks - input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + mcast_mask_a, mcast_mask_b // multicast masks + ); + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } + else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } } /// Set up the data needed by this collective for mma compute. @@ -612,7 +624,8 @@ struct CollectiveMma< cute::tuple> const& load_inputs, TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count, - bool did_batch_change) { + bool did_batch_change, + [[maybe_unused]] int curr_batch) { auto [unused_gA, unused_gB, tAgA_mkl, tBgB_nkl, tAsA, tBsB, @@ -739,6 +752,7 @@ struct CollectiveMma< // Methods to perform different parts of TMA/Tensormap modifications // + template CUTLASS_DEVICE auto tensormaps_init( Params const& mainloop_params, @@ -747,8 +761,8 @@ struct CollectiveMma< int32_t const sm_idx) const { cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; - cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; - cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[(sm_idx + sm_count) * NumTmaDescriptorsPerSm]; if (cute::elect_one_sync()) { // Bringing tensormaps from params to smem for modification later @@ -762,7 +776,29 @@ struct CollectiveMma< } __syncwarp(); - return cute::make_tuple(tma_desc_a, tma_desc_b); + struct TensorMapArray { + cute::TmaDescriptor* tma_desc_a; + cute::TmaDescriptor* tma_desc_b; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* tma_desc_a, cute::TmaDescriptor* tma_desc_b) : tma_desc_a(tma_desc_a), tma_desc_b(tma_desc_b) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(tma_desc_a + idx, tma_desc_b + idx); + } + }; + + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(tma_desc_a, tma_desc_b); + } + else { + return cute::make_tuple(tma_desc_a, tma_desc_b); + } } // Replace address for the global tensor (to be done by single thread) @@ -826,7 +862,7 @@ struct CollectiveMma< } // The entire warp must call this function collectively (that is, the instructions are aligned) - template + template CUTLASS_DEVICE void tensormaps_perform_update( @@ -834,7 +870,8 @@ struct CollectiveMma< Params const& mainloop_params, cute::tuple const& input_tensormaps, ProblemShape problem_shape, - int32_t next_batch) { + int32_t next_batch + ) { if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); @@ -849,18 +886,24 @@ struct CollectiveMma< // Ensure warp is converged before issuing tensormap fence release __syncwarp(); // Entire warp must do this (ie its aligned) - tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); } - template + template CUTLASS_DEVICE void tensormaps_cp_fence_release ( TensorMapStorage& shared_tensormaps, - cute::tuple const& input_tensormaps) { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); + cute::tuple const& input_tensormaps + ) { + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } } // Entire warp must do this (i.e. it's aligned) tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp new file mode 100644 index 00000000..2eb48fb8 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp @@ -0,0 +1,899 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + InternalStrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const* ptr_A; + InternalStrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K_A = get<2>(init_shape); + auto init_K_B = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + // Tensor pointers will be fixed before the first access + auto ptr_A_first_batch = recast_ptr(args.ptr_A); + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_K_A = get<2>(problem_shape_MNK); + + InternalStrideA stride_a = args.dA; + InternalStrideB stride_b = InternalStrideB{}; + + // Batches/Groups are managed by using appropriate pointers to input matrices. + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M, init_K_A, problem_shapes.groups()), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N, init_K_B, init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + args.ptr_A, + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 1; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count * NumTmaDescriptorsPerSm); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + return partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t num_groups, + [[maybe_unused]] int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,num_groups)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + auto ret = cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b // multicast masks + ); + + if constexpr (IsTensorMapUpdateAsync) { + return ret; + } + else { + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + return cute::tuple_cat(ret, cute::make_tuple(input_tensormaps)); + } + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, + int curr_batch) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, curr_batch); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx * NumTmaDescriptorsPerSm]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + struct TensorMapArray { + cute::TmaDescriptor* tma_desc_b; + + TensorMapArray() = default; + + CUTLASS_DEVICE + TensorMapArray(cute::TmaDescriptor* tma_desc_b) : tma_desc_b(tma_desc_b) {} + + CUTLASS_DEVICE + cute::tuple + operator[](int32_t idx) { + idx = idx % NumTmaDescriptorsPerSm; + return cute::make_tuple(tma_desc_b + idx); + } + }; + + if constexpr (IsTensorMapUpdateAsync) { + return TensorMapArray(tma_desc_b); + } + else { + return cute::make_tuple(tma_desc_b); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch + ) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release( + shared_tensormaps, + input_tensormaps + ); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps + ) { + if constexpr (WaitForInflightTmaRequests) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp index 5adc2b81..b6e4912c 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -262,7 +262,6 @@ public: AtomThrShapeMNK>; using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; - static constexpr int ScaleGranularityMN = size<0,0>(LayoutScale{}); static constexpr int ScaleGranularityK = size<1,0>(LayoutScale{}); using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig< @@ -273,11 +272,10 @@ public: decltype(make_shape(size<0>(TileShape{}), size<2>(TileShape{}))), decltype(make_shape(size<1>(TileShape{}), size<2>(TileShape{})))>; - static constexpr int ScaleTileShape_MN = get<0>(ScaleTileShape{}); + using SmemLayoutAtomScaleFull = decltype(ScaleConfig::smem_atom_layout_scale(ScaleTileShape{})); - static constexpr int ScaleK = get<1>(ScaleTileShape{}) / ScaleGranularityK; - - using SmemLayoutAtomScale = decltype(ScaleConfig::smem_atom_layout_scale(ScaleTileShape{})); + // Getting the SmemSizeMN and SmemSizeK from the mixed_dtype blockwise utils. + using SmemLayoutAtomScale = decltype(slice(make_coord(make_coord(_,0),make_coord(_,0)), SmemLayoutAtomScaleFull{})); static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -324,10 +322,10 @@ public: append(CtaShapeB_NK{}, Int{}), (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); - using SmemLayoutScale = decltype(make_layout( - append(shape(SmemLayoutAtomScale{}), Int{}), - append(stride(SmemLayoutAtomScale{}), size(filter_zeros(SmemLayoutAtomScale{}))) - )); + using SmemLayoutScale = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomScale{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, "Specialization requires Stages set to value 2 or more."); @@ -437,12 +435,13 @@ public: using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); - using TMA_Scale = decltype(make_tma_atom( + using TMA_Scale = decltype(make_tma_atom_A_sm100( GmemTiledCopyScale{}, make_tensor(static_cast(nullptr), LayoutScale{}), - SmemLayoutScale{}(_,_,cute::Int<0>{}), - ScaleTileShape{}, - size<2>(ClusterLayout_VMNK{})) + SmemLayoutScale{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) ); TMA_Scale tma_load_scale; @@ -576,13 +575,13 @@ public: ElementScale const* ptr_S = args.ptr_S; Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), args.layout_S); - typename Params::TMA_Scale tma_load_scale = make_tma_atom( - GmemTiledCopyScale{}, - tensor_scale, - SmemLayoutScale{}(_,_,cute::Int<0>{}), - ScaleTileShape{}, - size<2>(cluster_layout_vmnk) - ); + typename Params::TMA_Scale tma_load_scale = make_tma_atom_A_sm100( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { typename Params::TMAScaleParams scale_params{tma_load_scale, {}}; @@ -598,12 +597,13 @@ public: } else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor tensor_zero = make_tensor(detail::get_logical_ptr(args.ptr_Z), args.layout_S); - typename Params::TMA_Scale tma_load_zero = make_tma_atom( - GmemTiledCopyScale{}, - tensor_zero, - SmemLayoutScale{}(_,_,cute::Int<0>{}), - ScaleTileShape{}, - size<2>(cluster_layout_vmnk)); + typename Params::TMA_Scale tma_load_zero = make_tma_atom_A_sm100( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); typename Params::TMAScaleParams scale_params{tma_load_scale, tma_load_zero}; return { @@ -932,12 +932,11 @@ public: Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); Tensor tCgS_mkl = cta_mma.partition_A(gS_mkl); // (MMA, MMA_M, MMA_K, m, k, l) - Tensor tCsS = cta_mma.partition_A(sS); // Project the cta_layout for tma_scale along the n-modes auto [tSgS_mkl, tSsS] = tma_partition(params.tma_load_scale, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), - group_modes<0,3>(tCsS), group_modes<0,3>(tCgS_mkl)); + group_modes<0,3>(sS), group_modes<0,3>(tCgS_mkl)); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple( @@ -953,11 +952,10 @@ public: Tensor tCgZ_mkl = cta_mma.partition_A(gZ_mkl); // (MMA, MMA_M, MMA_K, m, k, l) - Tensor tCsZ = cta_mma.partition_A(sZ); // Project the cta_layout for tma_scale along the n-modes auto [tZgZ_mkl, tZsZ] = tma_partition(params.tma_load_zero, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), - group_modes<0,3>(tCsZ), group_modes<0,3>(tCgZ_mkl)); + group_modes<0,3>(sZ), group_modes<0,3>(tCgZ_mkl)); return cute::make_tuple( gA_mkl, gB_nkl, // for scheduler tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values @@ -1134,7 +1132,7 @@ public: setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); // Partition of thread -> shared and thread -> RF - auto fragment_compute = TiledMma::make_fragment_A(sACompute); + auto fragment_compute = TiledMma::make_fragment_A(sS); fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); auto r2t_tiled_copy = make_tmem_copy(ComputeCopyAtomA{}, fragment_compute(_,_,_,0)); auto src_copy_scale = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp index 03163121..bcdb01a9 100755 --- a/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp @@ -101,7 +101,8 @@ struct CollectiveMma< using StridePairB = StridePairB_; using SmemCopyAtomsA = SmemCopyAtomsA_; using SmemCopyAtomsB = SmemCopyAtomsB_; - + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; using TiledMma = TiledMma_; using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; using DispatchPolicy = MainloopSm120TmaWarpSpecializedSparseBlockScaled; diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 6f42fc7b..4076e52e 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -539,7 +539,7 @@ struct KernelTmaWarpSpecializedInputTransformSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; -// InputTransform GEMM +// Mixed Input Transform GEMM template< int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_ @@ -1177,6 +1177,21 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecialized { using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; }; +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = false; + using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; +}; + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, @@ -1241,7 +1256,6 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32 { constexpr static int Stages = Load2TransformPipelineStageCount; }; - // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int LoadABPipelineStageCount_, diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp index b86919b1..185b9d5d 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp @@ -132,7 +132,10 @@ public: using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + //For the case of RCGroupedGemm we are still GroupedGemm but our StrideA will not match with InternalStrideA + // Hence it's better to take this decision based upon StrideB + static constexpr bool IsGroupedGemmKernel = !(cute::is_same_v); using TileSchedulerTag = cute::conditional_t; using TileScheduler = typename detail::TileSchedulerSelector< @@ -140,25 +143,40 @@ public: using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsTensorMapUpdateAsync = not IsSchedDynamicPersistent; static constexpr bool IsDynamicCluster = not cute::is_static_v; static constexpr uint32_t MinTensorMapWorkspaceAlignment = 64; // Warp specialization thread count per threadblock - static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; - static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumTensorMapUpdaterThreads = IsTensorMapUpdateAsync ? NumThreadsPerWarp * 4 : 0; // Four warps to update tensor maps and plumb updated tileId. + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; - static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + static_assert( + SchedulerPipelineStageCount % (IsTensorMapUpdateAsync ? NumTensorMapUpdaterThreads / NumThreadsPerWarp : 1) == 0, + "SchedulerPipelineStageCount for async tensor map update kernels must be divisible by the number of asynchronous tensor map updater warps." + ); + + static_assert( + (!IsTensorMapUpdateAsync) + || CollectiveEpilogue::NumMaxSchedulerPipelineStageCount >= SchedulerPipelineStageCount, + "The epilog collective expected a less scheduler stage count. Consider relaxing its NumMaxSchedulerPipelineStageCount parameter." + ); + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + NumTensorMapUpdaterThreads + NumMainloopLoadThreads + NumMMAThreads + NumEpilogueLoadThreads + NumEpilogueThreads; static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static constexpr uint32_t NumFixupBarriers = 1; static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); - - static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + static constexpr uint32_t GenericRegisterRequirement = 136; + static constexpr uint32_t AccumRegisterRequirement = 232; // Pipeline and pipeline state types using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; @@ -179,17 +197,42 @@ public: cutlass::PipelineCLCFetchAsync, cutlass::PipelineAsync>; using CLCPipelineState = typename CLCPipeline::PipelineState; + using TensorMapReadyPipeline = cute::conditional_t, + CLCPipeline + >; + using TensorMapReadyPipelineState = typename TensorMapReadyPipeline::PipelineState; using CLCThrottlePipeline = cute::conditional_t, cutlass::PipelineEmpty>; using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + template + struct WithTensorMapUpdateInfo : public BaseResponse { + uint16_t batch_changed = 0; + uint16_t TMA_stage = 0; + WithTensorMapUpdateInfo() = default; + CUTLASS_DEVICE WithTensorMapUpdateInfo(BaseResponse const& response) : BaseResponse(response) {} + }; + + using CLCResponseWithAdditionalInformation = cute::conditional_t< + IsTensorMapUpdateAsync, + WithTensorMapUpdateInfo, + typename TileScheduler::CLCResponse + >; + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; // Kernel level shared memory storage struct SharedStorage { - struct PipelineStorage : cute::aligned_struct<16, _1> { + // The PipelineStorageImplWithoutAsyncUpdate and PipelineStorageImplWithAsyncUpdate only differ in the + // presence of the TensorMapReadyPipelineStorage. + // We could use some other technique to avoid duplication for the common members, but any technique + // we tried would break the MSVC build. + // As a workaround, we just copied the code. + + struct PipelineStorageImplWithoutAsyncUpdate : cute::aligned_struct<16, _1> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; @@ -204,9 +247,34 @@ public: alignas(16) AccumulatorPipelineStorage accumulator; alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; - } pipelines; + }; - alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + struct PipelineStorageImplWithAsyncUpdate : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + + // Below is the only difference between PipelineStorageImpl and PipelineStorageImpl + using TensorMapReadyPipelineStorage = typename TensorMapReadyPipeline::SharedStorage; + alignas(16) TensorMapReadyPipelineStorage tensor_map_ready; + }; + + using PipelineStorage = cute::conditional_t; + + PipelineStorage pipelines; + + alignas(16) CLCResponseWithAdditionalInformation clc_response[IsTensorMapUpdateAsync ? 2 : 1][SchedulerPipelineStageCount]; uint32_t tmem_base_ptr; struct TensorMapStorage : cute::aligned_struct<128, _1> { @@ -214,7 +282,7 @@ public: using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; alignas(128) EpilogueTensorMapStorage epilogue; alignas(128) MainloopTensorMapStorage mainloop; - } tensormaps; + } tensormaps[(NumTensorMapUpdaterThreads/NumThreadsPerWarp)+1]; struct TensorStorage : cute::aligned_struct<128, _1> { using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; @@ -226,7 +294,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -253,7 +320,9 @@ public: Sched = 1, MainloopLoad = 2, EpilogueLoad = 3, - Epilogue = 4 + Epilogue = 4, + // TensorMapUpdater starts at 256 thread alignment + TensorMapUpdater = 8 }; struct IsParticipant { @@ -262,6 +331,7 @@ public: uint32_t main_load = false; uint32_t epi_load = false; uint32_t epilogue = false; + uint32_t tensor_map_updater = false; }; // @@ -480,18 +550,20 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); auto problem_shape = params.problem_shape; // Account for more than one epilogue warp int warp_idx = canonical_warp_idx_sync(); - WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) - : WarpCategory::Epilogue; + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::TensorMapUpdater) ? WarpCategory::Epilogue + : WarpCategory::TensorMapUpdater; uint32_t lane_predicate = cute::elect_one_sync(); auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); int cluster_size = size(cluster_shape); uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); - bool is_first_cta_in_cluster = IsSchedDynamicPersistent ? (cta_rank_in_cluster == 0) : true; + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); bool is_mma_leader_cta = cta_coord_v == 0; constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; @@ -508,12 +580,57 @@ public: bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); IsParticipant is_participant = { (warp_category == WarpCategory::MMA), // mma - (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::Sched) && (IsSchedDynamicPersistent ? is_first_cta_in_cluster : true), // sched (warp_category == WarpCategory::MainloopLoad), // main_load (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load - (warp_category == WarpCategory::Epilogue) // epilogue + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::TensorMapUpdater) && IsTensorMapUpdateAsync // tensor_map_updater }; + int32_t sm_id = static_cast(cutlass::arch::SmId()); + if constexpr (IsGroupedGemmKernel) { + // In case user wants to engage less SMs than available on device + sm_id = blockIdx.x + (blockIdx.y * gridDim.x); + } + auto tensormaps_init_main_load = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + return collective_mainloop.template tensormaps_init( + params.mainloop, + shared_storage.tensormaps[0].mainloop, + params.hw_info.sm_count, + sm_id + ); + } + else { + return nullptr; + } + }; + + auto tensormaps_init_epi_load = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + return collective_epilogue.template tensormaps_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id + ); + } + else { + return nullptr; + } + }; + + decltype(tensormaps_init_main_load()) pre_init_main_load_tensormaps; + decltype(tensormaps_init_epi_load()) pre_init_epi_load_tensormaps; + + + if (is_participant.main_load) { + pre_init_main_load_tensormaps = tensormaps_init_main_load(); + } + if (is_participant.epi_load) { + pre_init_epi_load_tensormaps = tensormaps_init_epi_load(); + } + // Mainloop Load pipeline typename MainloopPipeline::Params mainloop_pipeline_params; if (WarpCategory::MainloopLoad == warp_category) { @@ -581,9 +698,14 @@ public: clc_pipeline_params.transaction_bytes = CLCResponseSize; } else { - clc_pipeline_params.consumer_arv_count = NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads; - if (is_epi_load_needed) { - clc_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + if constexpr (IsTensorMapUpdateAsync) { + clc_pipeline_params.consumer_arv_count = NumThreadsPerWarp; + } + else { + clc_pipeline_params.consumer_arv_count = NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads; + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } } } // Now declare the pipeline outside the if constexpr @@ -596,6 +718,32 @@ public: } }(); + auto tensor_map_ready_pipeline = [&] () { + if constexpr (IsGroupedGemmKernel) { + // TMA update ready pipeline + typename TensorMapReadyPipeline::Params tensor_map_ready_pipeline_params; + + if (WarpCategory::TensorMapUpdater == warp_category) { + tensor_map_ready_pipeline_params.role = TensorMapReadyPipeline::ThreadCategory::Producer; + } + else { + tensor_map_ready_pipeline_params.role = TensorMapReadyPipeline::ThreadCategory::Consumer; + } + + tensor_map_ready_pipeline_params.initializing_warp = 8; + tensor_map_ready_pipeline_params.producer_arv_count = NumThreadsPerWarp; + + tensor_map_ready_pipeline_params.consumer_arv_count = NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads; + if (is_epi_load_needed) { + tensor_map_ready_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } + return TensorMapReadyPipeline(shared_storage.pipelines.tensor_map_ready, tensor_map_ready_pipeline_params); + } + else { + return clc_pipeline; + } + }(); + // Mainloop-Epilogue pipeline typename AccumulatorPipeline::Params accumulator_pipeline_params; if (WarpCategory::MMA == warp_category) { @@ -672,19 +820,66 @@ public: CLCPipelineState clc_pipe_consumer_state; CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + TensorMapReadyPipelineState tensor_map_ready_pipe_consumer_state; + TensorMapReadyPipelineState tensor_map_ready_pipe_producer_state = cutlass::make_producer_start_state(); + AccumulatorPipelineState accumulator_pipe_consumer_state; AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); dim3 block_id_in_cluster = cute::block_id_in_cluster(); - int32_t sm_id = static_cast(cutlass::arch::SmId()); // Calculate mask after cluster barrier arrival mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); // TileID scheduler - TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); - typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + TileScheduler scheduler( + (!IsTensorMapUpdateAsync || is_participant.sched || is_participant.tensor_map_updater) + ? &shared_storage.clc_response[0][0] + : &shared_storage.clc_response[1][0], + params.scheduler, + block_id_in_cluster + ); + + auto work_tile_info = [&] () { + if constexpr (!IsSchedDynamicPersistent) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction. + // For the static grouped scheduler, the problem shapes + // might be produced by a previous kernel in global memory. + cutlass::arch::wait_on_dependent_grids(); + } + if constexpr (IsTensorMapUpdateAsync) { + return scheduler.initial_work_tile_info(cluster_shape, [] (typename TileScheduler::CLCResponse response) { + CLCResponseWithAdditionalInformation response_with_additional_info = response; + response_with_additional_info.TMA_stage = 0; + response_with_additional_info.batch_changed = 1; + return response_with_additional_info; + }); + } + else { + return scheduler.initial_work_tile_info(cluster_shape); + } + } (); + + auto get_tma_desc_offset = [] ([[maybe_unused]] const auto& tile_info) { + if constexpr (IsTensorMapUpdateAsync) { + return tile_info.TMA_stage; + } + else { + return 0; + } + }; + + auto get_tensormap = [] (auto& tensormaps, [[maybe_unused]] auto tma_desc_offset) { + if constexpr (IsTensorMapUpdateAsync) { + return tensormaps[tma_desc_offset]; + } + else { + return tensormaps; + } + }; + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); // @@ -698,44 +893,67 @@ public: // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups return; } - // In case user wants to engage less SMs than available on device - sm_id = blockIdx.x + (blockIdx.y * gridDim.x); } // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); if (is_participant.main_load) { - auto load_inputs = collective_mainloop.load_init( - problem_shape_MNKL, params.mainloop, - shared_storage.tensors.mainloop, - shared_storage.tensormaps.mainloop, - params.hw_info.sm_count, sm_id, work_tile_info.L_idx); - // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction cutlass::arch::wait_on_dependent_grids(); + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, + shared_storage.tensors.mainloop, + shared_storage.tensormaps[0].mainloop, + params.hw_info.sm_count, sm_id, problem_shape.groups(), work_tile_info.L_idx); + bool do_load_order_arrive = is_epi_load_needed; Tensor gA_mkl = get<0>(load_inputs); + // Fetch a copy of tensormaps for the CTA from Params - auto input_tensormaps = get(load_inputs); + auto input_tensormaps = [&] ([[maybe_unused]] auto inputs) { + if constexpr (IsTensorMapUpdateAsync) { + return pre_init_main_load_tensormaps; + } + else { + static constexpr size_t idx = rank(inputs) - 1; + return get(inputs); + } + } (load_inputs); + + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + + auto pad_inputs = [] (auto& inputs, [[maybe_unused]] auto tensormaps) { + if constexpr (IsTensorMapUpdateAsync) { + return cute::tuple_cat(inputs, cute::make_tuple(tensormaps)); + } + else { + return inputs; + } + }; // Initial batch's tensor address update // Even the first tile for a CTA can be from any of the batches. // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool is_first_iteration = true; bool did_batch_change = true; bool requires_clc_query = true; do { + auto tma_desc_offset = get_tma_desc_offset(work_tile_info); int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; if constexpr (IsGroupedGemmKernel) { problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); } - if (did_batch_change) { + if (IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change) { collective_mainloop.tensormaps_perform_update( - shared_storage.tensormaps.mainloop, + shared_storage.tensormaps[0].mainloop, params.mainloop, - input_tensormaps, + get_tensormap(input_tensormaps, tma_desc_offset), problem_shape, curr_batch ); @@ -763,10 +981,11 @@ public: params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, + pad_inputs(load_inputs, get_tensormap(input_tensormaps, tma_desc_offset)), cta_coord_mnk, k_tile_iter, k_tile_prologue, - did_batch_change + IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change, // did_batch_change + curr_batch ); mainloop_pipe_producer_state = mainloop_producer_state_next; @@ -779,10 +998,11 @@ public: params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, + pad_inputs(load_inputs, get_tensormap(input_tensormaps, tma_desc_offset)), cta_coord_mnk, k_tile_iter_next, k_tile_count - k_tile_prologue, - false /* did_batch_change - prologue loads handle tensormap acquire */ + false, /* did_batch_change - prologue loads handle tensormap acquire */ + curr_batch ); mainloop_pipe_producer_state = mainloop_producer_state_next_; @@ -791,16 +1011,17 @@ public: auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( work_tile_info, - clc_pipeline, - clc_pipe_consumer_state + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state ); work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); requires_clc_query = increment_pipe; if (increment_pipe) { - ++clc_pipe_consumer_state; + ++tensor_map_ready_pipe_consumer_state; } // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + is_first_iteration = false; did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); } while (work_tile_info.is_valid()); collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); @@ -808,6 +1029,12 @@ public: } else if (is_participant.sched) { + + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + // Grouped GEMM uses static tile scheduler if constexpr (IsSchedDynamicPersistent) { // Whether a new CLC query must be performed. @@ -853,20 +1080,135 @@ public: } else { - cutlass::arch::wait_on_dependent_grids(); + static_assert(IsTensorMapUpdateAsync || IsSchedDynamicPersistent, "We only support async tensor map update with static persistent scheduler"); + + auto update_tensor_map_stages = [&] (typename TileScheduler::CLCResponse next_work_tile_info_from_scheduler) { + if constexpr (IsTensorMapUpdateAsync) { + CLCResponseWithAdditionalInformation next_work_tile_info = next_work_tile_info_from_scheduler; + auto tensor_map_buffer_stage = work_tile_info.TMA_stage; + next_work_tile_info.batch_changed = work_tile_info.L_idx != next_work_tile_info.L_idx; + if (next_work_tile_info.batch_changed) { + ++tensor_map_buffer_stage; + } + next_work_tile_info.TMA_stage = tensor_map_buffer_stage; + return next_work_tile_info; + } + }; do { - auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, 1, update_tensor_map_stages); work_tile_info = next_work_tile_info; if (increment_pipe) { ++clc_pipe_producer_state; } } while (work_tile_info.is_valid()); - clc_pipeline.producer_tail(clc_pipe_producer_state); + + // Push additional invalid work items for all tensormap updater threads + for (int i = 0; i < NumTensorMapUpdaterThreads / NumThreadsPerWarp;) { + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, 1, update_tensor_map_stages); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++clc_pipe_producer_state; + ++i; + } + } + } + } + + else if (is_participant.tensor_map_updater) { + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + if constexpr (IsTensorMapUpdateAsync) { + auto updater_id = canonical_warp_idx_sync() - static_cast(WarpCategory::TensorMapUpdater); + + clc_pipe_consumer_state += updater_id; + tensor_map_ready_pipe_producer_state += updater_id; + + auto tensormaps_mainloop = collective_mainloop.tensormaps_init( + params.mainloop,shared_storage.tensormaps[updater_id+1].mainloop, params.hw_info.sm_count, sm_id); + auto tensormaps_epilogue_load = collective_epilogue.template tensormaps_init( + params.epilogue, shared_storage.tensormaps[updater_id+1].epilogue, params.hw_info.sm_count, sm_id); + auto tensormaps_epilogue_store = collective_epilogue.template tensormaps_init( + params.epilogue, shared_storage.tensormaps[updater_id+1].epilogue, params.hw_info.sm_count, sm_id); + + auto update_tensor_map_and_increment_pipe_if_needed = [&] (auto &next_work_tile_info, auto &increment_pipe) { + auto next_batch = next_work_tile_info.L_idx; + auto did_batch_change = next_work_tile_info.batch_changed; + + if (increment_pipe) { + tensor_map_ready_pipeline.producer_acquire(tensor_map_ready_pipe_producer_state); + if (next_work_tile_info.is_valid() && did_batch_change) { + auto tma_desc_offset = get_tma_desc_offset(next_work_tile_info); + collective_mainloop.template tensormaps_perform_update( + shared_storage.tensormaps[updater_id+1].mainloop, + params.mainloop, + tensormaps_mainloop[tma_desc_offset], + problem_shape, + next_batch + ); + + if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[updater_id+1].epilogue, + params.epilogue, + tensormaps_epilogue_load[tma_desc_offset], + problem_shape, + next_batch + ); + } + + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[updater_id+1].epilogue, + params.epilogue, + tensormaps_epilogue_store[tma_desc_offset], + problem_shape, + next_batch + ); + + collective_mainloop.tensormaps_fence_acquire(tensormaps_mainloop[tma_desc_offset]); + + if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.template tensormaps_fence_acquire(tensormaps_epilogue_load[tma_desc_offset]); + } + collective_epilogue.template tensormaps_fence_acquire(tensormaps_epilogue_store[tma_desc_offset]); + } + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + shared_storage.clc_response[1][tensor_map_ready_pipe_producer_state.index()] = next_work_tile_info; + cutlass::arch::fence_view_async_shared(); + cute::tma_desc_wait_group(); + } + + // Signal the other warps that the TMA update is complete + tensor_map_ready_pipeline.producer_commit(tensor_map_ready_pipe_producer_state); + tensor_map_ready_pipe_producer_state += (NumTensorMapUpdaterThreads / NumThreadsPerWarp); + clc_pipe_consumer_state += (NumTensorMapUpdaterThreads / NumThreadsPerWarp); + } + }; + + do { + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + update_tensor_map_and_increment_pipe_if_needed(next_work_tile_info, increment_pipe); + work_tile_info = next_work_tile_info; + + } while (work_tile_info.is_valid()); } } else if (is_participant.mma) { + // Tmem allocation sequence tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); __syncwarp(); @@ -875,19 +1217,13 @@ public: collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); auto mma_inputs = collective_mainloop.mma_init(tmem_storage, shared_storage.tensors.mainloop); + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + do { - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - if constexpr (IsGroupedGemmKernel) { problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); } @@ -915,6 +1251,17 @@ public: } ++accumulator_pipe_producer_state; + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state + ); + + if (increment_pipe) { + ++tensor_map_ready_pipe_consumer_state; + } + work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); } while (work_tile_info.is_valid()); @@ -949,6 +1296,12 @@ public: } else if (is_participant.epi_load) { + + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction cutlass::arch::wait_on_dependent_grids(); @@ -958,37 +1311,42 @@ public: int current_wave = 0; // Fetch a copy of tensormaps for the CTA from Params - auto epi_load_tensormap = get<0>(collective_epilogue.load_init( - params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + auto epi_load_tensormap = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + collective_epilogue.template load_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id + ); + return pre_init_epi_load_tensormaps; + } + else { + return get<0>(collective_epilogue.template load_init( + params.epilogue, shared_storage.tensormaps[0].epilogue, params.hw_info.sm_count, sm_id)); + } + } (); + // Initial batch's tensor address update // Even the first tile for a CTA can be from any of the batches. // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool is_first_iteration = true; bool did_batch_change = true; constexpr bool IsEpiLoad = true; do { int32_t curr_batch = work_tile_info.L_idx; - if (did_batch_change) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, + auto tma_desc_offset = get_tma_desc_offset(work_tile_info); + if (IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[0].epilogue, params.epilogue, - epi_load_tensormap, + get_tensormap(epi_load_tensormap, tma_desc_offset), problem_shape, curr_batch ); } bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); - // Get current work tile and fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - work_tile_info = next_work_tile_info; - - if (increment_pipe) { - ++clc_pipe_consumer_state; - } if (compute_epilogue) { if (do_load_order_wait) { @@ -1009,7 +1367,7 @@ public: TileShape{}, TiledMma{}, shared_storage.tensors.epilogue, - cute::make_tuple(epi_load_tensormap, did_batch_change), + cute::make_tuple(get_tensormap(epi_load_tensormap, tma_desc_offset), IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change), reverse_epi_n ); @@ -1017,9 +1375,21 @@ public: } current_wave++; + // Fetch the next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++tensor_map_ready_pipe_consumer_state; + } // Calculate the cta coordinates of the next work tile cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + is_first_iteration = false; did_batch_change = curr_batch != work_tile_info.L_idx; } while (work_tile_info.is_valid()); @@ -1035,6 +1405,11 @@ public: } else if (is_participant.epilogue) { + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_alloc(); + } + // Wait for tmem allocate here tmem_allocation_result_barrier.arrive_and_wait(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; @@ -1043,20 +1418,43 @@ public: auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); bool do_tail_store = false; // Fetch a copy of tensormaps for the CTA from Params - auto epi_store_tensormap = get<0>(collective_epilogue.store_init( - params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + auto epi_store_tensormap = [&] () { + if constexpr (IsTensorMapUpdateAsync) { + collective_epilogue.template store_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id + ); + + return collective_epilogue.template tensormaps_init( + params.epilogue, + shared_storage.tensormaps[0].epilogue, + params.hw_info.sm_count, + sm_id, + warp_idx_in_epi == 0 + ); + } + else { + return get<0>(collective_epilogue.template store_init( + params.epilogue, shared_storage.tensormaps[0].epilogue, params.hw_info.sm_count, sm_id)); + } + } (); + // Initial batch's tensor address update // Even the first tile for a CTA can be from any of the batches. // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool is_first_iteration = true; bool did_batch_change = true; constexpr bool IsEpiLoad = false; do { int32_t curr_batch = work_tile_info.L_idx; - if (did_batch_change && warp_idx_in_epi == 0) { - collective_epilogue.template tensormaps_perform_update( - shared_storage.tensormaps.epilogue, + auto tma_desc_offset = get_tma_desc_offset(work_tile_info); + if ((IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change) && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps[0].epilogue, params.epilogue, - epi_store_tensormap, + get_tensormap(epi_store_tensormap, tma_desc_offset), problem_shape, curr_batch ); @@ -1064,12 +1462,12 @@ public: // Fetch next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( work_tile_info, - clc_pipeline, - clc_pipe_consumer_state + tensor_map_ready_pipeline, + tensor_map_ready_pipe_consumer_state ); if (increment_pipe) { - ++clc_pipe_consumer_state; + ++tensor_map_ready_pipe_consumer_state; } // Accumulator stage slice @@ -1090,7 +1488,11 @@ public: // // Epilogue and write to gD // - auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + auto [ + load_state_next, + store_state_next, + acc_state_next + ] = collective_epilogue.template store( epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, @@ -1104,7 +1506,7 @@ public: TiledMma{}, accumulator, shared_storage.tensors.epilogue, - cute::make_tuple(epi_store_tensormap, did_batch_change) + cute::make_tuple(get_tensormap(epi_store_tensormap, tma_desc_offset), IsTensorMapUpdateAsync ? is_first_iteration : did_batch_change) ); epi_load_pipe_consumer_state = load_state_next; epi_store_pipe_producer_state = store_state_next; @@ -1114,6 +1516,7 @@ public: work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + is_first_iteration = false; did_batch_change = curr_batch != work_tile_info.L_idx; } while (work_tile_info.is_valid()); @@ -1138,6 +1541,11 @@ public: } else { + if constexpr (IsTensorMapUpdateAsync) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + } } }; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp index 76432e1e..9d5e5b4e 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp @@ -222,7 +222,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -420,6 +419,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); auto problem_shape = params.problem_shape; // Account for multiple epilogue and transformation warps diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp index 2ec1049b..ee670016 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp @@ -234,7 +234,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -493,6 +492,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); auto problem_shape = params.problem_shape; // Account for more than one epilogue warp diff --git a/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp index 21ff5959..626294ca 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp @@ -188,7 +188,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -371,6 +370,8 @@ public: using namespace cute; using X = Underscore; + + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp index fb62f1b8..45523d6d 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp @@ -210,7 +210,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -408,6 +407,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp index 24efff6f..ea7a1710 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp @@ -205,7 +205,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Device side arguments struct Arguments { @@ -396,6 +395,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp index 55c18c9a..e4805725 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp @@ -202,7 +202,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Device side arguments struct Arguments { @@ -396,6 +395,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp index 11d381d2..82ef91fd 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp @@ -216,7 +216,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -402,6 +401,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp index a5f6eb9b..9cc80a06 100644 --- a/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp @@ -216,7 +216,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -456,6 +455,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index 8cf885f8..8da289d2 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -128,11 +128,11 @@ public: scheduler_sm90(params.params_sm90_, clc_response_ptr) { } // Returns the initial work tile info that will be computed over - template + template CUTLASS_DEVICE auto - initial_work_tile_info(ClusterShape cluster_shape) { - return scheduler_sm90.initial_work_tile_info(cluster_shape); + initial_work_tile_info(ClusterShape cluster_shape, CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo info) { return info;}) { + return scheduler_sm90.initial_work_tile_info(cluster_shape, callback_before_commit); } template @@ -189,15 +189,16 @@ public: ); } - template + template CUTLASS_DEVICE auto advance_to_next_work( CLCPipeline& clc_pipeline, CLCPipelineState clc_pipe_producer_state, - uint32_t advance_count = 1) { + uint32_t advance_count = 1, + CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo info) { return info;}) { - return scheduler_sm90.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, advance_count); + return scheduler_sm90.advance_to_next_work(clc_pipeline, clc_pipe_producer_state, advance_count, callback_before_commit); } // @@ -304,11 +305,11 @@ public: } // Kernel helper function to get next CLC ID - template + template CUTLASS_DEVICE auto fetch_next_work( - WorkTileInfo work_tile_info, + WorkTileWithCallbackInfo work_tile_info, CLCPipeline& clc_pipeline, CLCPipelineState clc_pipe_consumer_state) { @@ -320,7 +321,7 @@ private: // Methods // [[nodiscard]] CUTLASS_DEVICE - static CLCResponse + static auto load_query_response(uint32_t smem_ptr) { return UnderlyingScheduler::load_query_response(smem_ptr); } diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp index 8d6e286b..224abf38 100644 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp @@ -37,6 +37,8 @@ #include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" #include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/conv/detail.hpp" + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel::detail { @@ -177,6 +179,44 @@ public: return params; } + template + static Params + to_underlying_arguments( + cutlass::conv::ConvProblemShape problem_shape, + TileShapeMNK tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace = nullptr + ) { + + auto problem_shape_mnkl = [&] () { + // Infer im2col linearization from ConvOp and TileShape + constexpr bool is_linearized_M = (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) + && cute::depth<0>(TileShapeMNK{}) == _0{}; + constexpr bool is_linearized_K = ConvOp == conv::Operator::kWgrad && cute::depth<2>(TileShapeMNK{}) == _1{}; + if constexpr (is_linearized_M || is_linearized_K) { + // transformation + im2col linearization + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); + } + else { + // transformation + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + } + }(); + + return to_underlying_arguments( + problem_shape_mnkl, + tile_shape_mnk, + atom_thr_shape_mnk, + cluster_shape_mnk, + hw_info, + args, + workspace + ); + } + static bool can_implement(Arguments const& args) { return UnderlyingStreamKScheduler::can_implement(args); diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp index 06fd138d..6708a88c 100644 --- a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp @@ -233,7 +233,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -508,6 +507,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); auto problem_shape = params.problem_shape; // Account for more than one epilogue warp @@ -551,7 +551,6 @@ public: typename MainloopABPipeline::Params mainloop_ab_pipeline_params; if (WarpCategory::MainloopABLoad == warp_category) { mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; - // Initialize the barrier for TMA load prefetch } if (WarpCategory::MMA == warp_category) { mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp index ae93b2ff..c873bfcf 100644 --- a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp @@ -212,7 +212,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -425,6 +424,7 @@ public: using namespace cute; using X = Underscore; + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); @@ -483,6 +483,7 @@ public: if (WarpCategory::MainloopABLoad == warp_category) { mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; // Initialize the barrier for TMA load prefetch + } if (WarpCategory::MMA == warp_category) { mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index 92749b19..5a3961d2 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -135,7 +135,7 @@ public: // Sink scheduler params as a member Params scheduler_params; - SchedulerResponse *response_ptr_ = nullptr; + void *response_ptr_ = nullptr; ProblemShape cached_problem_shapes_[2]; // @@ -225,6 +225,8 @@ public: for (int group = 0; group < groups; group++) { auto ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes.get_host_problem_shape(group)), cute::shape<0>(cta_shape))); auto ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes.get_host_problem_shape(group)), cute::shape<1>(cta_shape))); + if(ctas_along_m <= 0) ctas_along_m = 1; + if(ctas_along_n <= 0) ctas_along_n = 1; auto problem_blocks_m = round_up(ctas_along_m, cute::get<0>(cluster_shape)); auto problem_blocks_n = round_up(ctas_along_n, cute::get<1>(cluster_shape)); total_ctas += problem_blocks_m * problem_blocks_n; @@ -301,7 +303,7 @@ public: int32_t log_swizzle_size, RasterOrder raster_order) { - int32_t valid_tile = 1; + uint8_t valid_tile = 1; // Use a warp to "speculatively" check if the work tile maps to the next 32 groups int lane_idx = canonical_lane_idx(); @@ -329,7 +331,8 @@ public: auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); group_info.problem_blocks_along_raster_order = raster_order == RasterOrder::AlongN ? problem_blocks_n : problem_blocks_m; group_info.total_tiles = problem_blocks_m * problem_blocks_n; - } else { + } + else { group_info.total_tiles = INT_MAX; } @@ -428,23 +431,31 @@ public: scheduler_params.log_swizzle_size_, scheduler_params.raster_order_); } - template + + template CUTLASS_DEVICE auto advance_to_next_work( TileSchedulerPipeline& scheduler_pipeline, TileSchedulerPipelineState scheduler_pipe_producer_state, - uint32_t advance_count = 1) { + uint32_t advance_count = 1, + CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo info) { return info;}) { current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); auto work_tile = get_current_work_for_linear_idx(current_work_linear_idx_); + using WorkTileWithCallbackInfo = decltype(callback_before_commit(work_tile)); + WorkTileWithCallbackInfo work_tile_with_callback_info = work_tile; scheduler_pipeline.producer_acquire(scheduler_pipe_producer_state); + if (work_tile_with_callback_info.is_valid()) { + work_tile_with_callback_info = callback_before_commit(work_tile); + } + if (cute::elect_one_sync()) { - response_ptr_[scheduler_pipe_producer_state.index()] = work_tile; + reinterpret_cast(response_ptr_)[scheduler_pipe_producer_state.index()] = work_tile_with_callback_info; cutlass::arch::fence_view_async_shared(); scheduler_pipeline.producer_commit(scheduler_pipe_producer_state); } - return cute::make_tuple(work_tile, true); + return cute::make_tuple(work_tile_with_callback_info, true); } // Returns whether the block assigned this work should compute the epilogue for the corresponding @@ -555,31 +566,37 @@ public: } // Kernel helper function to get next work tile - template + template CUTLASS_DEVICE auto fetch_next_work( - WorkTileInfo work_tile_info, + WorkTileWithCallbackInfo work_tile_with_callback_info, TileSchedulerPipeline& scheduler_pipeline, TileSchedulerPipelineState scheduler_pipe_consumer_state) { - if (continue_current_work(work_tile_info)) { - return cute::make_tuple(work_tile_info, true); + if (continue_current_work(work_tile_with_callback_info)) { + return cute::make_tuple(work_tile_with_callback_info, true); } scheduler_pipeline.consumer_wait(scheduler_pipe_consumer_state); - auto work_tile = response_ptr_[scheduler_pipe_consumer_state.index()]; + work_tile_with_callback_info = reinterpret_cast(response_ptr_)[scheduler_pipe_consumer_state.index()]; cutlass::arch::fence_view_async_shared(); scheduler_pipeline.consumer_release(scheduler_pipe_consumer_state); - return cute::make_tuple(work_tile, true); + return cute::make_tuple(work_tile_with_callback_info, true); } // Returns the initial work tile info that will be computed over - template + template CUTLASS_DEVICE auto - initial_work_tile_info(ClusterShape) { - return get_current_work_for_linear_idx(current_work_linear_idx_); + initial_work_tile_info(ClusterShape, CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo response) { return response;}) { + auto work_tile = get_current_work_for_linear_idx(current_work_linear_idx_); + using WorkTileWithCallbackInfo = decltype(callback_before_commit(work_tile)); + WorkTileWithCallbackInfo work_tile_with_callback_info = work_tile; + if (work_tile_with_callback_info.is_valid()) { + work_tile_with_callback_info = callback_before_commit(work_tile); + } + return work_tile_with_callback_info; } }; diff --git a/include/cutlass/half.h b/include/cutlass/half.h index 118a80d7..0c377ef5 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -918,12 +918,12 @@ half_t operator--(half_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(long double x) { +cutlass::half_t operator""_hf(long double x) { return cutlass::half_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(unsigned long long int x) { +cutlass::half_t operator""_hf(unsigned long long int x) { return cutlass::half_t(int(x)); } diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 43047eae..65d19de9 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -84,6 +84,14 @@ struct integer_subbyte { integer_subbyte(float value) : integer_subbyte(static_cast(value)) {} + CUTLASS_HOST_DEVICE + integer_subbyte(double value) + : integer_subbyte(static_cast(value)) {} + + CUTLASS_HOST_DEVICE + integer_subbyte(signed char value) + : integer_subbyte(static_cast(value)) {} + // CUTLASS code commonly converts both signed and unsigned integers // into integer_subbyte, so the class provides both explicit // conversions. @@ -114,10 +122,9 @@ struct integer_subbyte { : storage(reinterpret_cast(value) & bits_mask_) { if constexpr (Signed) { - [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); + // no need to check lower bound since input value is unsigned [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; - assert(value >= lower_bound); - assert(value <= upper_bound); + assert(value <= static_cast(upper_bound)); } else { [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; @@ -205,12 +212,6 @@ using int2b_t = integer_subbyte<2, true>; /// 2-bit Unsigned integer type using uint2b_t = integer_subbyte<2, false>; -/// 3-bit Integer type -using int3b_t = integer_subbyte<3, true>; - -/// 3-bit Unsigned integer type -using uint3b_t = integer_subbyte<3, false>; - /// 4-bit Integer type using int4b_t = integer_subbyte<4, true>; diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 86ba43a4..649b158f 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -607,6 +607,7 @@ struct alignment_of { enum { value = 16 }; }; + #if !defined(CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED) #define CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED (__CUDACC_VER_MAJOR__ >= 13) #endif @@ -676,6 +677,7 @@ struct alignment_of { #endif + // Specializations for volatile/const qualified types template struct alignment_of : alignment_of {}; diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 68bdb26e..0738eee7 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -300,6 +300,7 @@ bool relatively_equal(float_ue4m3_t a, float_ue4m3_t b, float_ue4 return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); } + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 7bc13e17..b778e21d 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -467,12 +467,12 @@ tfloat32_t operator--(tfloat32_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(long double x) { +cutlass::tfloat32_t operator""_tf32(long double x) { return cutlass::tfloat32_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { +cutlass::tfloat32_t operator""_tf32(unsigned long long int x) { return cutlass::tfloat32_t(int(x)); } diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index 68896d6b..c6f8024c 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -60,6 +60,64 @@ #endif #endif +CUTLASS_HOST_DEVICE +uint64_t umul128( + uint64_t multiplier, + uint64_t multiplicand, + uint64_t *high_product +) { + +#if defined(CUTLASS_INT128_ARITHMETIC) + return _umul128(multiplier, multiplicand, high_product); +#else + const uint64_t mask = 0xFFFFFFFF; + + uint64_t a_lo = multiplier & mask; + uint64_t a_hi = multiplier >> 32; + uint64_t b_lo = multiplicand & mask; + uint64_t b_hi = multiplicand >> 32; + + uint64_t p_ll = a_lo * b_lo; + uint64_t p_lh = a_lo * b_hi; + uint64_t p_hl = a_hi * b_lo; + uint64_t p_hh = a_hi * b_hi; + + uint64_t p_mid = (p_ll >> 32) + (p_lh & mask) + (p_hl & mask); + uint64_t r_lo = (p_ll & mask) + (p_mid << 32); + uint64_t r_hi = (p_lh & mask) + (p_hl & mask) + p_hh; + + *high_product = r_hi; + return r_lo; +#endif +} + + +CUTLASS_HOST_DEVICE +uint64_t udiv128(uint64_t high, uint64_t low, uint64_t divisor, uint64_t *remainder_ptr) { +#if defined(CUTLASS_INT128_ARITHMETIC_DIV) + return _udiv128(high, low, divisor, remainder_ptr); +#else + uint64_t quotient = 0, remainder = 0; + uint64_t const bit = 1; + for (int32_t i=127; i>=0; --i) { + uint64_t r = 0; + if (i >= 64) { + r = ((high >> (i - 64)) & bit); + } + else { + r = ((low >> i) & bit); + } + remainder = (remainder << 1) | r; + if (remainder >= divisor) { + remainder -= divisor; + quotient |= (bit << i); + } + } + *remainder_ptr = remainder; + return quotient; +#endif +} + namespace cutlass { ///! Unsigned 128b integer type @@ -157,16 +215,13 @@ struct alignas(16) uint128_t uint128_t y{}; #if defined(CUTLASS_UINT128_NATIVE) y.native = native * rhs; -#elif defined(CUTLASS_INT128_ARITHMETIC) +#else // Multiply by the low part - y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); + y.hilo_.lo = umul128(hilo_.lo, rhs, &y.hilo_.hi); // Add the high part and ignore the overflow uint64_t overflow{0}; - y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); -#else - CUTLASS_UNUSED(rhs); - exception(); + y.hilo_.hi += umul128(hilo_.hi, rhs, &overflow); #endif return y; } @@ -178,13 +233,10 @@ struct alignas(16) uint128_t uint64_t quotient{0}; #if defined(CUTLASS_UINT128_NATIVE) quotient = uint64_t(native / divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) +#else // implemented using MSVC's arithmetic intrinsics uint64_t remainder{0}; - quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -#else - CUTLASS_UNUSED(divisor); - exception(); + quotient = udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #endif return quotient; } @@ -196,12 +248,9 @@ struct alignas(16) uint128_t uint64_t remainder{0}; #if defined(CUTLASS_UINT128_NATIVE) remainder = uint64_t(native % divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) - // implemented using MSVC's arithmetic intrinsics - (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else - CUTLASS_UNUSED(divisor); - exception(); + // implemented using MSVC's arithmetic intrinsics + (void)udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #endif return remainder; } @@ -214,13 +263,9 @@ struct alignas(16) uint128_t #if defined(CUTLASS_UINT128_NATIVE) quotient = uint64_t(native / divisor); remainder = uint64_t(native % divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) - // implemented using MSVC's arithmetic intrinsics - quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else - CUTLASS_UNUSED(remainder); - CUTLASS_UNUSED(divisor); - exception(); + // implemented using MSVC's arithmetic intrinsics + quotient = udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #endif return quotient; } diff --git a/include/cutlass/version.h b/include/cutlass/version.h index 57a73a5f..3324fbb8 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -35,8 +35,8 @@ #include #define CUTLASS_MAJOR 4 -#define CUTLASS_MINOR 2 -#define CUTLASS_PATCH 1 +#define CUTLASS_MINOR 3 +#define CUTLASS_PATCH 0 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/media/docs/cpp/cute/02_layout_algebra.md b/media/docs/cpp/cute/02_layout_algebra.md index 465d3aef..c35b9a67 100644 --- a/media/docs/cpp/cute/02_layout_algebra.md +++ b/media/docs/cpp/cute/02_layout_algebra.md @@ -151,7 +151,7 @@ For example, * `(3,6,2,8) / 9 => (1,2,2,8)` * `(3,6,2,8) / 72 => (1,1,1,4)` -To compute the strides of the strided layout, the residues of the above operation are used to scale the strides of `A`. For instance, the last example `(3,6,2,8):(w,x,y,z) / 72` with strides `(w,x,y,z)` produces `(72*w,24*x,4*y,2*z)` as the strides of the strided layout. +To compute the strides of the strided layout, the residues of the above operation are used to scale the strides of `A`. For instance, the last example `(3,6,2,8):(w,x,y,z) / 72` with strides `(w,x,y,z)` produces `(72*w,24*x,4*x,2*z)` as the strides of the strided layout. As you may have noticed, we can only divide shapes by certain values and get a sensible result. This is called the **stride divisibility condition** and is statically checked in CuTe when possible. @@ -388,7 +388,7 @@ Informally, `logical_divide(A, B)` splits a layout `A` into two modes -- in the Formally, this can be written as -$A \oslash B := A \circ (B,B^*)$ +$$A \oslash B := A \circ (B,B^*)$$ and implemented as ```cpp diff --git a/media/docs/pythonDSL/cute_dsl_api.rst b/media/docs/pythonDSL/cute_dsl_api.rst index 2461af39..f8bc6ee8 100644 --- a/media/docs/pythonDSL/cute_dsl_api.rst +++ b/media/docs/pythonDSL/cute_dsl_api.rst @@ -8,6 +8,6 @@ CuTe DSL API changelog cute - cute_arch cute_nvgpu + pipeline utils diff --git a/media/docs/pythonDSL/cute_dsl_api/changelog.rst b/media/docs/pythonDSL/cute_dsl_api/changelog.rst index fc4d1fd4..4de8d6e0 100644 --- a/media/docs/pythonDSL/cute_dsl_api/changelog.rst +++ b/media/docs/pythonDSL/cute_dsl_api/changelog.rst @@ -2,7 +2,32 @@ Changelog for CuTe DSL API changes ====================================== -`4.2.0 `_ (2025-09-15) +`4.3.0 `_ (2025-10-07) +============================================================================== + +* Debuggability improvements: + - Supported source location tracking for DSL APIs + - Supported dumping PTX and SASS code +* Remove deprecated ``cutlass._utils.SMEM_CAPACITY[""]`` and ``cutlass.utils.ampere_helpers`` +* Support calling nested functions without capturing variables inside dynamic control flow +* Replace usage of ``cute.arch.barrier`` in examples with corresponding APIs in ``pipeline`` + - Use ``pipeline.sync`` for simple cases like synchronizing the whole CTA + - Use ``pipeline.NamedBarrier`` to customize barriers with different participating threads and barrier id +* Added new APIs ``repeat`` and ``repeat_as_tuple`` +* Added new APIs ``make_rmem_tensor`` to replace ``make_fragment`` with better naming +* Added new APIs ``make_rmem_tensor_like`` which create rmem tensor from a tensor using the same shape with compact col-major strides +* Added ``TmemAllocator`` for allocating tensor memory +* Updated ``SmemAllocator.allocate`` to support allocation of a single scalar value +* Fixed ``TensorSSA.reduce`` to support static value as initial value +* Updated docstring for following APIs to be more concise and easier to understand: + - ``make_layout_tv`` + - ``is_static`` + - ``PipelineAsync`` + - ``SmemAllocator`` +* Fixed documentation for ``pipeline``, ``utils`` and ``cute.math`` + + +`4.2.0 `_ (2025-09-10) ============================================================================== * Added back ``cute.make_tiled_copy`` per the request from community @@ -40,7 +65,7 @@ Changelog for CuTe DSL API changes - Introduce S2T CopyOps in `tcgen05/copy.py `_. - Introduce BlockScaled layout utilities in `blockscaled_layout.py `_ for creating the required scale factor layouts in global memory, shared memory and tensor memory. -* ``cutlass.cute.compile`` now supports compilation options. Refer to `JIT compilation options `_ for more details. +* ``cutlass.cute.compile`` now supports compilation options. Refer to `JIT compilation options `_ for more details. * ``cutlass.cute.testing.assert_`` now works for device JIT function. Specify ``--enable-device-assertions`` as compilation option to enable. * ``cutlass.cute.make_tiled_copy`` is now deprecated. Please use ``cutlass.cute.make_tiled_copy_tv`` instead. * Shared memory capacity query diff --git a/media/docs/pythonDSL/cute_dsl_api/cute.rst b/media/docs/pythonDSL/cute_dsl_api/cute.rst index bd5d5c56..f64e6b03 100644 --- a/media/docs/pythonDSL/cute_dsl_api/cute.rst +++ b/media/docs/pythonDSL/cute_dsl_api/cute.rst @@ -9,3 +9,9 @@ cutlass.cute :show-inheritance: :special-members: __init__ :private-members: + +.. toctree:: + :maxdepth: 2 + :hidden: + + cute_arch diff --git a/media/docs/pythonDSL/cute_dsl_api/cute_arch.rst b/media/docs/pythonDSL/cute_dsl_api/cute_arch.rst index 4e2d4d0d..a82d9cfc 100644 --- a/media/docs/pythonDSL/cute_dsl_api/cute_arch.rst +++ b/media/docs/pythonDSL/cute_dsl_api/cute_arch.rst @@ -1,17 +1,18 @@ .. _cute_arch: -cutlass.cute.arch -================= +arch +==== -The ``cute.arch`` module contains wrappers around NVVM-level MLIR Op builders that seamlessly -inter-operate with the Python types used in CUTLASS Python. Another benefit of wrapping these Op -builders is that the source location can be tracked with the ``@dsl_user_op`` decorator. Available -functions include +The ``cute.arch`` module provides lightweight wrappers for NVVM Operation builders which implement CUDA built-in +device functions such as ``thread_idx``. It integrates seamlessly with CuTe DSL types. -- basic API like ``thr_idx``; -- functions related to the direct management of mbarriers; -- low-level SMEM management (prefer using the ``SmemAllocator`` class); -- TMEM management. +These wrappers enable source location tracking through the ``@dsl_user_op`` +decorator. The module includes the following functionality: + +- Core CUDA built-in functions such as ``thread_idx``, ``warp_idx``, ``block_dim``, ``grid_dim``, ``cluster_dim``, and related functions +- Memory barrier management functions including ``mbarrier_init``, ``mbarrier_arrive``, ``mbarrier_wait``, and associated operations +- Low-level shared memory (SMEM) management capabilities, with ``SmemAllocator`` as the recommended interface +- Low-level tensor memory (TMEM) management capabilities, with ``TmemAllocator`` as the recommended interface API documentation ----------------- diff --git a/media/docs/pythonDSL/cute_dsl_api/pipeline.rst b/media/docs/pythonDSL/cute_dsl_api/pipeline.rst new file mode 100644 index 00000000..eda23991 --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_api/pipeline.rst @@ -0,0 +1,9 @@ +cutlass.pipeline +================ + +.. automodule:: cutlass.pipeline + :members: + :undoc-members: + :show-inheritance: + :special-members: __init__ + :private-members: diff --git a/media/docs/pythonDSL/cute_dsl_api/utils.rst b/media/docs/pythonDSL/cute_dsl_api/utils.rst index 086bef60..9079a7f8 100644 --- a/media/docs/pythonDSL/cute_dsl_api/utils.rst +++ b/media/docs/pythonDSL/cute_dsl_api/utils.rst @@ -1,9 +1,19 @@ cutlass.utils ============= +The ``cutlass.utils`` module contains utilities for developing kernels with CuTe DSL. + .. automodule:: cutlass.utils :members: :undoc-members: :show-inheritance: :special-members: __init__ :private-members: + :exclude-members: sm90_make_smem_layout_a, sm90_make_smem_layout_b, sm90_make_smem_layout_epi + +.. toctree:: + :maxdepth: 2 + :hidden: + + utils_sm90 + utils_sm100 diff --git a/media/docs/pythonDSL/cute_dsl_api/utils_sm100.rst b/media/docs/pythonDSL/cute_dsl_api/utils_sm100.rst new file mode 100644 index 00000000..1ac4814f --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_api/utils_sm100.rst @@ -0,0 +1,10 @@ +.. _utils_sm100: + +Utilities for SM100 +=================== + +.. automodule:: cutlass.utils.sm100 + :members: + :undoc-members: + :show-inheritance: + :special-members: __init__ diff --git a/media/docs/pythonDSL/cute_dsl_api/utils_sm90.rst b/media/docs/pythonDSL/cute_dsl_api/utils_sm90.rst new file mode 100644 index 00000000..2acd3145 --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_api/utils_sm90.rst @@ -0,0 +1,10 @@ +.. _utils_sm90: + +Utilities for SM90 +================== + +.. automodule:: cutlass.utils.sm90 + :members: + :undoc-members: + :show-inheritance: + :special-members: __init__ diff --git a/media/docs/pythonDSL/cute_dsl_general/debugging.rst b/media/docs/pythonDSL/cute_dsl_general/debugging.rst index 6302100b..93c75112 100644 --- a/media/docs/pythonDSL/cute_dsl_general/debugging.rst +++ b/media/docs/pythonDSL/cute_dsl_general/debugging.rst @@ -15,6 +15,14 @@ Understanding these limitations will help you avoid potential pitfalls from the Please refer to :doc:`../limitations` for more details. +Source Code Correlation +----------------------- + +CuTe DSL provides Python code to PTX/SASS correlation to enable the profiling/debugging of generated kernels with debug symbols by generating line info when compiling the kernel. + +You can enable that globally via the environment variable CUTE_DSL_LINEINFO=1. Alternative, you can use compilation options to enable that per kernel. Please refer to :doc:`./dsl_jit_compilation_options` for more details. + + DSL Debugging ------------- @@ -75,6 +83,48 @@ This helps you verify whether the IR is generated as expected. export CUTE_DSL_KEEP_IR=1 +Dump the generated PTX & CUBIN +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For users familiar with PTX and SASS, CuTe DSL supports dumping the generated PTX and CUBIN. + +.. code:: bash + + # Dump generated PTX in a .ptx file (default: False) + export CUTE_DSL_KEEP_PTX=1 + + # Dump generated cubin in a .cubin file (default: False) + export CUTE_DSL_KEEP_CUBIN=1 + +To further get SASS from cubin, users can use ``nvdisasm`` (usually installed with CUDA toolkit) to disassemble the cubin. + +.. code:: bash + + nvdisasm your_dsl_code.cubin > your_dsl_code.sass + + +Access the dumped contents programmatically +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For compiled kernels, the generated PTX/CUBIN/IR can be accessed programmatically as well through following attributes: + +- ``__ptx__``: The generated PTX code of the compiled kernel. +- ``__cubin__``: The generated CUBIN data of the compiled kernel. +- ``__mlir__``: The generated IR code of the compiled kernel. + +.. code:: python + + compiled_foo = cute.compile(foo, ...) + print(f"PTX: {compiled_foo.__ptx__}") + with open("foo.cubin", "wb") as f: + f.write(compiled_foo.__cubin__) + + +Change the dump directory +~~~~~~~~~~~~~~~~~~~~~~~~~ + +By default, all dumped files are saved in the current working directory. To specify a different directory for the dumped files, please set the environment variable CUTE_DSL_DUMP_DIR accordingly. + Kernel Functional Debugging ---------------------------- @@ -122,6 +172,7 @@ For detecting memory errors and race conditions: Please refer to the `compute-sanitizer documentation `_ for more details. + Conclusion ---------- diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst index 18ee4e2d..18970012 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst @@ -124,7 +124,7 @@ JIT function arguments with |CUSTOM_TYPES| - ``__extract_mlir_values__``: Generate a dynamic expression for the current object. - ``__new_from_mlir_values__``: Create a new object from MLIR values. -Refer to `typing.py `__ for more details on these protocol APIs. +Refer to `typing.py `__ for more details on these protocol APIs. Depending on different cases of the |CUSTOM_TYPES|, |DSL| provides easy ways to adopt |CUSTOM_TYPES| for JIT function arguments. diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst index d5dd42f8..1f4bcb25 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.rst @@ -18,9 +18,11 @@ Compilation options allow you to customize how your JIT-compiled functions are b These options can be passed as keyword arguments to ``cute.compile`` or set globally for all JIT compilations. The available options and their effects are described in the following sections, along with usage examples to help you get started. +The |DSL| provides multiple ways to specify compilation options - either by specifying additional arguments to ``cute.compile`` or by using a more Pythonic approach with separate Python types for ``cute.compile``. -``cute.compile`` Compilation Options ------------------------------------- + +``cute.compile`` Compilation Options as strings +----------------------------------------------- You can provide additional compilation options as a string when calling ``cute.compile``. The |DSL| uses ``argparse`` to parse these options and will raise an error if any invalid options are specified. @@ -36,10 +38,30 @@ You can provide additional compilation options as a string when calling ``cute.c - Optimization level of compilation. The higher the level, the more optimizations are applied. The valid value range is [0, 3]. - 3 (highest level of optimization) - int - * - ``enable-device-assertions`` - - Enable device code assertions. + * - ``enable-assertions`` + - Enable host and device code assertions. - False - bool + * - ``keep-cubin`` + - Keep the generated CUBIN file. + - False + - bool + * - ``keep-ptx`` + - Keep the generated PTX file. + - False + - bool + * - ``ptxas-options`` + - The options to pass to the PTX Compiler library. + - "" + - str + * - ``generate-line-info`` + - Generate line information for debugging. + - False + - bool + * - ``gpu-arch`` + - The GPU architecture to compile for. + - "" + - str You can use the following code to specify compilation options: @@ -47,4 +69,34 @@ You can use the following code to specify compilation options: jit_executor_with_opt_level_2 = cute.compile(add, 1, 2, options="--opt-level 2") jit_executor_with_opt_level_1 = cute.compile(add, 1, 2, options="--opt-level 1") - jit_executor_with_enable_device_assertions = cute.compile(add, 1, 2, options="--enable-device-assertions") + jit_executor_with_enable_device_assertions = cute.compile(add, 1, 2, options="--enable-assertions") + jit_executor_with_keep_cubin = cute.compile(add, 1, 2, options="--keep-cubin") + jit_executor_with_keep_ptx = cute.compile(add, 1, 2, options="--keep-ptx") + jit_executor_with_ptxas_options = cute.compile(add, 1, 2, options="--ptxas-options '--opt-level=2'") + + +``cute.compile`` Compilation Options as separate Python types +------------------------------------------------------------- + +Alternatively, you can also use a more Pythonic way to specify compilation options with separate Python types. +Compilation options can be programmatically composed using tuple and passed to ``cute.compile`` separately. + +.. code-block:: python + + from cutlass.cute import OptLevel, EnableAssertions, GenerateLineInfo, KeepCUBIN, KeepPTX + + my_debugging_options = (OptLevel(1), EnableAssertions, GenerateLineInfo, KeepCUBIN, KeepPTX) + compiled_kernel_1 = cute.compile[my_debugging_options](my_kernel_1, ...) + compiled_kernel_2 = cute.compile[my_debugging_options](my_kernel_2, ...) + +This approach causes invalid options to raise errors immediately, making it much easier to detect typos when specifying multiple options. +Notebly, boolean options are automatically converted to True instances of the option type for convenience. + +.. code-block:: python + + jit_executor_with_opt_level_2 = cute.compile[OptLevel(2)](add, 1, 2) + jit_executor_with_opt_level_1 = cute.compile[OptLevel(1)](add, 1, 2) + jit_executor_with_enable_device_assertions = cute.compile[EnableAssertions](add, 1, 2) + jit_executor_with_keep_cubin = cute.compile[KeepCUBIN](add, 1, 2) + jit_executor_with_keep_ptx = cute.compile[KeepPTX](add, 1, 2) + jit_executor_with_ptxas_options = cute.compile[PtxasOptions("--opt-level=2")](add, 1, 2) \ No newline at end of file diff --git a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst index edd9eb94..7d223ebb 100644 --- a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst +++ b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst @@ -63,7 +63,7 @@ The full signature of from_dlpack is as follows: .. code-block:: python - def from_dlpack(tensor, assumed_align=None): + def from_dlpack(tensor, assumed_align=None, use_32bit_stride=False): The ``assumed_align`` integer parameter specifies the alignment of the tensor in unit of bytes. The tensor's base address must be divisible by ``assumed_align``. When not provided explicitly, @@ -72,6 +72,13 @@ information is part of the pointer type in the generated IR. Therefore, programs alignments have a different IR and identical IRs are required for hitting the kernel caching mechanism of |DSL|. +The ``use_32bit_stride`` parameter determines whether to use 32-bit stride for the tensor's dynamic stride values. +By default, it is set to False (64bit) to ensure that address calculations do not risk overflow. For smaller +problem sizes (where ``cosize(layout_of_tensor) <= Int32_MAX``), users may set it to True (32bit) to improve performance +by reducing register usage and the number of address calculation instructions. When ``use_32bit_stride`` is set +to True, a runtime check is performed to ensure that the layout does not overflow. Please note that this parameter +only has an effect when the tensor's layout is marked as dynamic. + Code Example ~~~~~~~~~~~~ @@ -242,6 +249,10 @@ The following example demonstrates how to use ``mark_layout_dynamic`` to specify t7 = from_dlpack(b).mark_layout_dynamic(leading_dim=3) # Expected strides[leading_dim] == 1, but got 4 + c = torch.empty(1000000000, 1000000000) + t8 = from_dlpack(c, use_32bit_stride=True).mark_layout_dynamic() + # Layout in DLTensorWrapper has int32 overflow risk. Please set use_32bit_stride to False. + Mark the Tensor's Layout as Dynamic with ``mark_compact_shape_dynamic`` ----------------------------------------------------------------------- @@ -398,6 +409,12 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to ) # The stride_order is not consistent with the layout + c = torch.empty(1000000000, 1000000000) + t13 = from_dlpack(c, use_32bit_stride=True).mark_compact_shape_dynamic( + mode=0, divisibility=1 + ) + # Layout in DLTensorWrapper has int32 overflow risk. Please set use_32bit_stride to False. + Bypass the DLPack Protocol -------------------------- diff --git a/media/docs/pythonDSL/quick_start.rst b/media/docs/pythonDSL/quick_start.rst index 18569b17..495ec8ff 100644 --- a/media/docs/pythonDSL/quick_start.rst +++ b/media/docs/pythonDSL/quick_start.rst @@ -8,7 +8,15 @@ The CUTLASS DSL 4.0 release currently supports **Linux** and **Python 3.12** onl Installation ----------------------- -To install the CUTLASS DSL, run: +To ensure compatibility with the examples and code on `GitHub `_, +use the `requirements.txt `_ file from the corresponding commit in the repository. + +.. code-block:: bash + + git clone https://github.com/NVIDIA/cutlass.git + pip install -r cutlass/python/CuTeDSL/requirements.txt + +If you just want to try out the last known stable release of the CUTLASS DSL (may not compatible with the latest examples and code), run: .. code-block:: bash @@ -18,9 +26,6 @@ The ``nvidia-cutlass-dsl`` wheel includes everything needed to generate GPU kern the same NVIDIA driver version as the `CUDA Toolkit 12.9 `_. -To ensure compatibility with the examples and code on `GitHub `_, -use the ``requirements.txt`` file from the corresponding commit in the repository. - Recommended Dependencies --------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 575ee076..d14493b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "4.2.1.0" +version = "4.2.0.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" diff --git a/python/CuTeDSL/cutlass/__init__.py b/python/CuTeDSL/cutlass/__init__.py index f2c7ed26..38bd7626 100644 --- a/python/CuTeDSL/cutlass/__init__.py +++ b/python/CuTeDSL/cutlass/__init__.py @@ -11,13 +11,12 @@ from .cutlass_dsl import ( Constexpr, + dsl_user_op, as_numeric, min, max, and_, or_, - all_, - any_, not_, all_, any_, @@ -29,6 +28,7 @@ from .cutlass_dsl import ( while_generate, yield_out, # Control-flow with AST pre-processor + range, range_constexpr, range_dynamic, const_expr, @@ -46,6 +46,7 @@ from .cute.typing import * # Utilities not belonging to CuTe from . import utils as utils +from . import pipeline as pipeline # Used as internal symbol from . import cutlass_dsl as _dsl diff --git a/python/CuTeDSL/cutlass/base_dsl/__init__.py b/python/CuTeDSL/cutlass/base_dsl/__init__.py index cbb617dc..c1929784 100644 --- a/python/CuTeDSL/cutlass/base_dsl/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/__init__.py @@ -12,6 +12,6 @@ # Local module imports from .dsl import * from .runtime import * -from ._mlir_helpers import lru_cache_ir +from ._mlir_helpers import lru_cache_ir, dsl_user_op from .env_manager import get_str_env_var, detect_gpu_arch diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py index 607a24d0..9d0ebee6 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py @@ -15,9 +15,9 @@ This module provides MLIR Dialect helper functions from . import arith from .lru_cache_ir import lru_cache_ir +from .op import dsl_user_op - -__all__ = ["arith", "lru_cache_ir"] +__all__ = ["arith", "lru_cache_ir", "dsl_user_op"] try: from . import gpu diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py index 60cc8db3..9afb8d05 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py @@ -20,6 +20,7 @@ from ..common import * from ..._mlir import ir # type: ignore from ..._mlir.extras import types as T # type: ignore from ..._mlir.dialects import arith, nvgpu, math, builtin # type: ignore +from .op import dsl_user_op from .lru_cache_ir import lru_cache_ir @@ -393,9 +394,10 @@ def _binary_op(op): class ArithValue(ir.Value): """Overloads operators for MLIR's Arith dialects binary operations.""" - def __init__(self, v, signed: Union[bool, None] = None): + @dsl_user_op + def __init__(self, v, signed: Union[bool, None] = None, *, loc=None, ip=None): if isinstance(v, int): - v = arith.constant(self.type, v) + v = arith.constant(self.type, v, loc=loc, ip=ip) super().__init__(v) elem_ty = element_type(self.type) @@ -406,6 +408,7 @@ class ArithValue(ir.Value): def with_signedness(self, signed: Union[bool, None]): return type(self)(self, signed) + @dsl_user_op def __neg__(self, *, loc=None, ip=None): if self.type == T.bool(): raise TypeError( @@ -418,6 +421,7 @@ class ArithValue(ir.Value): c0 = arith.constant(self.type, 0, loc=loc, ip=ip) return arith.subi(c0, self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue": if self.is_float and other.is_float: @@ -433,12 +437,13 @@ class ArithValue(ir.Value): else: raise DSLNotImplemented(f"Unsupported '{self} ** {other}'") + @dsl_user_op @_binary_op def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__pow__(self, loc=loc, ip=ip) # arith operators - + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __add__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -447,6 +452,7 @@ class ArithValue(ir.Value): else: return arith.addi(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -455,6 +461,7 @@ class ArithValue(ir.Value): else: return arith.subi(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -463,6 +470,7 @@ class ArithValue(ir.Value): else: return arith.muli(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -473,6 +481,7 @@ class ArithValue(ir.Value): rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip) return arith.divf(lhs, rhs, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -484,6 +493,7 @@ class ArithValue(ir.Value): else: return arith.divui(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -494,31 +504,38 @@ class ArithValue(ir.Value): else: return arith.remui(self, other, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__add__(self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__sub__(self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__mul__(self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__truediv__(self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__floordiv__(self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__mod__(self, loc=loc, ip=ip) # Comparison operators (comparison doesn't have right-hand-side variants) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -529,6 +546,7 @@ class ArithValue(ir.Value): else: return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __le__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -539,6 +557,7 @@ class ArithValue(ir.Value): else: return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -547,6 +566,7 @@ class ArithValue(ir.Value): else: return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -556,6 +576,7 @@ class ArithValue(ir.Value): else: return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -566,6 +587,7 @@ class ArithValue(ir.Value): else: return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -577,25 +599,30 @@ class ArithValue(ir.Value): return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip) # Unary operators + @dsl_user_op def __invert__(self, *, loc=None, ip=None) -> "ArithValue": return arith.xori(self, arith.constant(self.type, -1)) # Bitwise operations + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __and__(self, other, *, loc=None, ip=None) -> "ArithValue": return arith.andi(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __or__(self, other, *, loc=None, ip=None) -> "ArithValue": return arith.ori(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue": return arith.xori(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue": @@ -604,27 +631,33 @@ class ArithValue(ir.Value): else: return arith.shrui(self, other, loc=loc, ip=ip) + @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue": return arith.shli(self, other, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue": return arith.andi(other, self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue": return arith.ori(other, self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue": return arith.xori(other, self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__rshift__(self, loc=loc, ip=ip) + @dsl_user_op @_binary_op def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue": return other.__lshift__(self, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py index a0b0d050..e9b4837f 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py @@ -13,7 +13,6 @@ This module provides MLIR GPU Dialect helper functions """ - from ..._mlir import ir from ..._mlir.dialects import gpu, arith, scf from ..._mlir.extras import types as T diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py index 57d717b4..94ce6a42 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py @@ -23,7 +23,6 @@ def make_layout(...): """ - from functools import lru_cache, wraps from ..._mlir import ir # type: ignore diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py index 3989c75e..7882d2af 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py @@ -13,7 +13,6 @@ This module provides MLIR's OP helper functions """ - import inspect from functools import wraps @@ -21,13 +20,43 @@ from ..._mlir import ir def dsl_user_op(opFunc): + """ + This is a decorator that needs to be used in each user-facing API to + manage source location for toolchain. + + :param opFunc: The user-facing API function. + :type opFunc: Callable + :return: The wrapped user-facing API function. + :rtype: Callable + """ + @wraps(opFunc) def wrapper(*args, **kwargs): loc = kwargs.pop("loc", None) if loc is None: frame = inspect.currentframe().f_back - file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0) - loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc) + frameInfo = inspect.getframeinfo(frame) + # In Python < 3.11, getframeinfo returns a NamedTuple without positions + if not hasattr(frameInfo, "positions"): + file_loc = ir.Location.file( + frameInfo.filename, + frameInfo.lineno, + 0, + ) + else: + file_loc = ir.Location.file( + frameInfo.filename, + frameInfo.positions.lineno, + frameInfo.positions.col_offset, + ) + loc = ir.Location.name( + ( + "".join([c.strip() for c in frameInfo.code_context]) + if frameInfo.code_context + else frameInfo.function + ), + childLoc=file_loc, + ) res_or_list = opFunc(*args, **kwargs, loc=loc) return res_or_list diff --git a/python/CuTeDSL/cutlass/base_dsl/arch.py b/python/CuTeDSL/cutlass/base_dsl/arch.py new file mode 100644 index 00000000..78b9bb67 --- /dev/null +++ b/python/CuTeDSL/cutlass/base_dsl/arch.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from enum import Enum +import re +from typing import Callable, List + + +class Arch(Enum): + # sm_arch = (major, minor, suffix) + # Ampere + sm_80 = (8, 0, "") + sm_86 = (8, 6, "") + sm_87 = (8, 7, "") + # Ada + sm_89 = (8, 9, "") + # Hopper + sm_90 = (9, 0, "") + sm_90a = (9, 0, "a") + # Blackwell + sm_100 = (10, 0, "") + sm_100a = (10, 0, "a") + sm_100f = (10, 0, "f") + sm_101 = (10, 1, "") + sm_101a = (10, 1, "a") + sm_101f = (10, 1, "f") + sm_110 = (11, 0, "") + sm_110a = (11, 0, "a") + sm_110f = (11, 0, "f") + sm_120 = (12, 0, "") + sm_120a = (12, 0, "a") + sm_120f = (12, 0, "f") + sm_121 = (12, 1, "") + sm_121a = (12, 1, "a") + sm_121f = (12, 1, "f") + + def __init__(self, major, minor, suffix): + self.major = major + self.minor = minor + self.suffix = suffix + + @classmethod + def _missing_(cls, value): + """Support creating Arch enum from (major, minor, suffix) tuple""" + if not isinstance(value, tuple): + raise ValueError(f"invalid arguments for Arch: {value}") + major, minor, suffix = None, None, None + if len(value) == 2: + major, minor, suffix = *value, "" + else: + raise ValueError(f"invalid arguments for Arch: {value}") + + return cls(major, minor, suffix) + + def __repr__(self): + return self.__str__() + + @classmethod + def from_string(cls, arch_str): + pattern = r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$" + match = re.match(pattern, arch_str) + if not match: + raise ValueError(f"Invalid architecture string format: {arch_str}") + major, minor, suffix = match.groups() + return cls((int(major), int(minor), suffix)) + + @classmethod + def filter(cls, criterion: Callable[["Arch"], bool]) -> List["Arch"]: + """ + Filter the archs by the given criterion. + """ + return [arch for arch in cls if criterion(arch)] + + def is_family_of(self, arch: "Arch") -> bool: + """ + Check if this arch is equal or higher in the same family than the given arch, so that the family-specific features can be used. + + Example: + + .. code-block:: python + + >>> arch = Arch.sm_103f + >>> arch.is_family_of(Arch.sm_100f) + True + + """ + # sm_101 is renamed to sm_110, sm_101f is family of sm_110f, but is not family of sm_100f + if self in [Arch.sm_101a, Arch.sm_101f]: + return arch.major == 11 and arch.minor == 0 + + return ( + self.major == arch.major + and self.minor >= arch.minor + and self.suffix in ["a", "f"] + ) + + def __lt__(self, other): + if not isinstance(other, Arch): + return NotImplemented + return (self.major, self.minor) < (other.major, other.minor) + + def __le__(self, other): + if not isinstance(other, Arch): + return NotImplemented + return (self.major, self.minor) <= (other.major, other.minor) + + def __gt__(self, other): + if not isinstance(other, Arch): + return NotImplemented + return (self.major, self.minor) > (other.major, other.minor) + + def __ge__(self, other): + if not isinstance(other, Arch): + return NotImplemented + return (self.major, self.minor) >= (other.major, other.minor) diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py b/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py index 7b11474c..fdb09b9e 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py @@ -112,9 +112,9 @@ class Executor: unroll_full=False, prefetch_stages=None, ): - assert ( - self._loop_execute_range_dynamic - ), "Functions must be set before execution." + assert self._loop_execute_range_dynamic, ( + "Functions must be set before execution." + ) log().debug("start [%s] stop [%s] step [%s]", start, stop, step) return self._loop_execute_range_dynamic( @@ -433,9 +433,9 @@ def compare_executor(left, comparators, ops): Raises: AssertionError: If the executor function is not set before execution """ - assert ( - executor._compare_executor is not None - ), "Function must be set before execution." + assert executor._compare_executor is not None, ( + "Function must be set before execution." + ) return executor._compare_executor(left, comparators, ops) @@ -521,24 +521,6 @@ def range_value_check(*args): ) -def range_perf_warning(filename, lineno, *args): - has_dynamic_expr = False - for arg in args: - if executor._is_dynamic_expression(arg): - has_dynamic_expr = True - break - if not has_dynamic_expr: - warnings.warn_explicit( - ( - "This loop is no longer unrolled and may cause performance regression. " - "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants." - ), - category=DSLOptimizationWarning, - filename=filename, - lineno=lineno, - ) - - @lru_cache(maxsize=1) def _get_self_module(): """ @@ -614,3 +596,15 @@ def get_locals_or_none(locals, symbols): else: variables.append(None) return variables + + +def closure_check(closures): + """ + Check if the closures have any captures + """ + for closure in closures: + if closure.__closure__: + raise DSLRuntimeError( + f"Function `{closure.__name__}` is a closure that captures variables and is not supported in dynamic control flow", + suggestion="Please implicitly pass in captured variables as arguments", + ) diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py index 11f2d1ae..a4531914 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py @@ -33,8 +33,11 @@ to generate dialect-specific operations for `for` and `if` statements. """ import ast +import contextlib import importlib import inspect +import os +import sys import textwrap import warnings from dataclasses import dataclass @@ -72,6 +75,9 @@ class OrderedSet: def __sub__(self, other): return OrderedSet(key for key in self._dict if key not in other) + def __bool__(self): + return bool(self._dict) + def intersections(self, others): """Compute the intersection of this set with multiple other sets. @@ -94,11 +100,38 @@ class ImportInfo: """ Information about an import expression. """ + module_path: str attr_name: Optional[str] alias_name: str +@dataclass +class TryImportInfo: + """ + Represents information about a try-import block in the AST. + + This dataclass is used to capture and organize the import statements that appear + within the different clauses of a try-except-else-finally block. Each field holds + a list of import statements (or related nodes) that are encountered in the corresponding + clause of the try block. + + Attributes: + try_imports (list): Import statements found in the 'try' clause. + except_imports (list): Import statements found in any 'except' clauses. + else_imports (list): Import statements found in the 'else' clause, if present. + finally_imports (list): Import statements found in the 'finally' clause, if present. + + This structure allows the preprocessor to track and process imports that are conditionally + executed depending on exception handling logic. + """ + + try_imports: list + except_imports: list + else_imports: list + finally_imports: list + + @dataclass class ScopeManager: """ @@ -191,6 +224,89 @@ class DSLPreprocessor(ast.NodeTransformer): set_location(node, lineno, col_offset) return node + def _get_imports_from_ast(self, node, module): + """ + Recursively extracts all import statements from the given AST node. + + This method traverses the AST of a Python module and collects information about all + import statements, including standard imports, from-imports (with support for relative imports), + and imports that appear within try/except/finally blocks. For try blocks, it also handles + imports that may be conditionally executed in except, else, or finally clauses, specifically + looking for handlers that catch ImportError, ModuleNotFoundError, or Exception. + + Args: + node: The AST node (typically an ast.Module) to search for import statements. + module: The Python module object corresponding to the AST, used for resolving relative imports. + + Returns: + A list of ImportInfo and TryImportInfo objects representing all discovered imports in the AST. + """ + imports = [] + alias = lambda n: n.asname if n.asname else n.name + for child_node in ast.iter_child_nodes(node): + if isinstance(child_node, ast.Import): + for name in child_node.names: + imports.append( + ImportInfo( + module_path=name.name, + attr_name=None, + alias_name=alias(name), + ) + ) + elif isinstance(child_node, ast.ImportFrom): + module_name = child_node.module + if child_node.level > 0: + # Handle relative imports. + if module.__package__: + package_name = module.__package__.rsplit( + ".", child_node.level - 1 + )[0] + module_name = f"{package_name}.{module_name}" + else: + # Handle typically some local import like: + # from .common_dense_gemm import DenseGemmKernel + # where there is no __package__, either None + # when in __main__ or '' otherwise. + module_name = f"{module_name}" + for name in child_node.names: + imports.append( + ImportInfo( + module_path=module_name, + attr_name=name.name, + alias_name=alias(name), + ) + ) + # ast.TryStar is introduced in Python 3.11. Can't use directly in Python 3.10 and lower. + elif isinstance(child_node, (ast.Try, getattr(ast, "TryStar", ast.Try))): + # Handle try-catch + try_imports = self._get_imports_from_ast( + ast.Module(body=child_node.body), module + ) + # search handler for ImportError or ModuleNotFoundError + except_imports = [] + for handler in child_node.handlers: + if handler.type == None or handler.type.id in [ + "ImportError", + "ModuleNotFoundError", + "Exception", + ]: + except_imports = self._get_imports_from_ast( + ast.Module(body=handler.body), module + ) + break + else_imports = self._get_imports_from_ast( + ast.Module(body=child_node.orelse), module + ) + finally_imports = self._get_imports_from_ast( + ast.Module(body=child_node.finalbody), module + ) + imports.append( + TryImportInfo( + try_imports, except_imports, else_imports, finally_imports + ) + ) + return imports + def _get_module_imports(self, decorated_func): """Extract imports from the module containing the decorated function""" imports = [] @@ -202,70 +318,82 @@ class DSLPreprocessor(ast.NodeTransformer): source = inspect.getsource(module) module_ast = ast.parse(source) - # Extract imports from the full module - alias = lambda n: n.asname if n.asname else n.name - for node in ast.walk(module_ast): - if isinstance(node, ast.Import): - for name in node.names: - imports.append( - ImportInfo( - module_path=name.name, - attr_name=None, - alias_name=alias(name), - ) - ) - elif isinstance(node, ast.ImportFrom): - module_name = node.module - if node.level > 0: - # Handle relative imports - package_name = module.__package__.rsplit( - ".", node.level - 1 - )[0] - module_name = f"{package_name}.{module_name}" - for name in node.names: - imports.append( - ImportInfo( - module_path=module_name, - attr_name=name.name, - alias_name=alias(name), - ) - ) + imports = self._get_imports_from_ast(module_ast, module) except (IOError, TypeError): pass return imports + def try_import_first_and_then_local_import(self, module_path): + @contextlib.contextmanager + def local_import(module_path): + # Directory where some local import might happen: + local_dir = os.path.dirname(self.file_name) + # Momentarily insert the directory where the local import + # used to happen, so the import can find the module. + sys.path.insert(0, local_dir) + try: + yield importlib.import_module(module_path) + finally: + # Clean up even in the case of an exception. + sys.path.pop(0) + + try: + # Try the normal import first. + return importlib.import_module(module_path) + except (ImportError, AttributeError): + # If the normal import failed, tried a local import because we might + # have lost track of sys.path changes. + with local_import(module_path) as module: + return module + + def exec_import(self, import_info, exec_globals): + module_path, attr_name, alias_name = ( + import_info.module_path, + import_info.attr_name, + import_info.alias_name, + ) + module = self.try_import_first_and_then_local_import(module_path) + if attr_name: + if attr_name == "*": + if hasattr(module, "__all__"): + attrs = module.__all__ + else: + attrs = [name for name in dir(module) if not name.startswith("_")] + else: + attrs = [attr_name] + + for attr in attrs: + alias = attr if attr_name == "*" else alias_name + exec_globals[alias] = getattr(module, attr) + else: + exec_globals[alias_name] = module + + def exec_imports(self, import_infos, exec_globals): + for import_info in import_infos: + if isinstance(import_info, ImportInfo): + try: + self.exec_import(import_info, exec_globals) + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to import {import_info.module_path}: {str(e)}" + ) + elif isinstance(import_info, TryImportInfo): + try: + self.exec_imports(import_info.try_imports, exec_globals) + except (ImportError, AttributeError): + self.exec_imports(import_info.except_imports, exec_globals) + else: + self.exec_imports(import_info.else_imports, exec_globals) + finally: + self.exec_imports(import_info.finally_imports, exec_globals) + def exec(self, function_name, original_function, code_object, exec_globals): # Get imports from the original module module_imports = self._get_module_imports(original_function) # Import all required modules - for import_info in module_imports: - module_path, attr_name, alias_name = ( - import_info.module_path, - import_info.attr_name, - import_info.alias_name, - ) - try: - module = importlib.import_module(module_path) - if attr_name: - if attr_name == "*": - if hasattr(module, "__all__"): - attrs = module.__all__ - else: - attrs = [ - name for name in dir(module) if not name.startswith("_") - ] - else: - attrs = [attr_name] - - for attr in attrs: - alias = attr if attr_name == "*" else alias_name - exec_globals[alias] = getattr(module, attr) - else: - exec_globals[alias_name] = module - except (ImportError, AttributeError) as e: - raise ImportError(f"Failed to import {module_path}: {str(e)}") + self.exec_imports(module_imports, exec_globals) # Execute the transformed code log().info( @@ -304,12 +432,19 @@ class DSLPreprocessor(ast.NodeTransformer): return [] # Step 1. Parse the given function - file_name = inspect.getsourcefile(function_pointer) - lines, start_line = inspect.getsourcelines(function_pointer) - dedented_source = textwrap.dedent("".join(lines)) - tree = ast.parse(dedented_source, filename=file_name) - # Bump the line numbers so they match the real source file - ast.increment_lineno(tree, start_line - 1) + try: + file_name = inspect.getsourcefile(function_pointer) + lines, start_line = inspect.getsourcelines(function_pointer) + dedented_source = textwrap.dedent("".join(lines)) + tree = ast.parse(dedented_source, filename=file_name) + # Bump the line numbers so they match the real source file + ast.increment_lineno(tree, start_line - 1) + except Exception: + # Under REPL mode, there is no way to get source of a function object, error out + raise DSLRuntimeError( + f"Failed to parse function {func_name}", + suggestion="DSL does not support REPL mode, save the function to a file instead.", + ) # Step 1.2 Check the decorator if not self.check_decorator(tree.body[0]): @@ -471,6 +606,7 @@ class DSLPreprocessor(ast.NodeTransformer): local_closure = self.local_closures file_name = self.file_name region_node = node + called_closures = OrderedSet() class RegionAnalyzer(ast.NodeVisitor): force_store = False @@ -538,11 +674,7 @@ class DSLPreprocessor(ast.NodeTransformer): if isinstance(node.func, ast.Name): func_name = node.func.id if func_name in local_closure: - raise DSLAstPreprocessorError( - f"Function `{func_name}` is a closure and is not supported in for/if statements", - filename=file_name, - snippet=ast.unparse(region_node), - ) + called_closures.add(func_name) # Classes are mutable by default. Mark them as write. If they are # dataclass(frozen=True), treat them as read in runtime. @@ -560,7 +692,7 @@ class DSLPreprocessor(ast.NodeTransformer): write_args = list(write_args.intersections(active_symbols)) invoked_args = list(invoked_args.intersections(active_symbols)) - return write_args + invoked_args, len(write_args) + return write_args + invoked_args, len(write_args), called_closures def extract_range_args(self, iter_node): args = iter_node.args @@ -874,26 +1006,6 @@ class DSLPreprocessor(ast.NodeTransformer): lineno=node.iter.lineno, ) - warning_call = None - if range_kind == "range" and is_builtin_range and not has_keyword: - # Warn about possible performance regression due to behavior change - warning_call = ast.Expr( - ast.Call( - func=self._create_module_attribute( - "range_perf_warning", - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=[ - ast.Constant(value=self.file_name), - ast.Constant(value=node.iter.lineno), - ] - + node.iter.args, - keywords=[], - ) - ) - ast.copy_location(warning_call, node.iter) - is_prefixed_range = range_kind == "range" and not is_builtin_range check_call = None if range_kind == "range_dynamic" or is_prefixed_range: @@ -908,7 +1020,7 @@ class DSLPreprocessor(ast.NodeTransformer): if check_call is not None: new_for_node = [check_call] + new_for_node - return new_for_node if warning_call is None else [warning_call] + new_for_node + return new_for_node @staticmethod def _hoist_expr_to_assignments(expr, name): @@ -1036,6 +1148,24 @@ class DSLPreprocessor(ast.NodeTransformer): extra_exprs, ) + def _create_closure_check_call(self, called_closures, node): + return ast.Expr( + ast.Call( + func=self._create_module_attribute( + "closure_check", + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[ + ast.List( + elts=[ast.Name(id=c, ctx=ast.Load()) for c in called_closures], + ctx=ast.Load(), + ) + ], + keywords=[], + ) + ) + def transform_for_loop(self, node, active_symbols): # Check for early exit and raise exception self.check_early_exit(node, "for") @@ -1083,8 +1213,8 @@ class DSLPreprocessor(ast.NodeTransformer): start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter) unroll, unroll_full = self.extract_unroll_args(node.iter) prefetch_stages = self.extract_prefetch_stages_args(node.iter) - write_args, full_write_args_count = self.analyze_region_variables( - node, active_symbols + write_args, full_write_args_count, called_closures = ( + self.analyze_region_variables(node, active_symbols) ) if has_step and self.client_module_name[0] == "cutlass": @@ -1097,6 +1227,9 @@ class DSLPreprocessor(ast.NodeTransformer): if target_var_is_active_before_loop: exprs.append(pre_loop_expr) + if called_closures: + exprs.append(self._create_closure_check_call(called_closures, node)) + func_name = f"loop_body_{self.counter}" self.counter += 1 @@ -1207,6 +1340,7 @@ class DSLPreprocessor(ast.NodeTransformer): node, ) elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): + def create_downcast_call(arg): return ast.copy_location( ast.Call( @@ -1221,6 +1355,7 @@ class DSLPreprocessor(ast.NodeTransformer): ), arg, ) + module = self.function_globals.get(func.value.id) if isinstance(module, ModuleType) and module.__package__.endswith( "._mlir.dialects" @@ -1388,14 +1523,17 @@ class DSLPreprocessor(ast.NodeTransformer): return [check, node] active_symbols = self.scope_manager.get_active_symbols() - with self.scope_manager: # Check for early exit and raise exception self.check_early_exit(node, "while") - write_args, full_write_args_count = self.analyze_region_variables( - node, active_symbols + write_args, full_write_args_count, called_closures = ( + self.analyze_region_variables(node, active_symbols) ) + exprs = [] + if called_closures: + exprs.append(self._create_closure_check_call(called_closures, node)) + func_name = f"while_region_{self.counter}" self.counter += 1 @@ -1404,7 +1542,7 @@ class DSLPreprocessor(ast.NodeTransformer): ) assign = self.create_cf_call(func_name, write_args, node) - return [func_def] + assign + return exprs + [func_def] + assign def visit_Try(self, node): with self.scope_manager: @@ -1514,6 +1652,7 @@ class DSLPreprocessor(ast.NodeTransformer): "In": "in", "NotIn": "not in", } + def compare_ops_to_str(self, node): names = [ ast.Constant(value=self.cmpops[op.__class__.__name__]) for op in node.ops @@ -1556,9 +1695,13 @@ class DSLPreprocessor(ast.NodeTransformer): # Check for early exit and raise exception self.check_early_exit(node, "if") - yield_args, full_write_args_count = self.analyze_region_variables( - node, active_symbols + yield_args, full_write_args_count, called_closures = ( + self.analyze_region_variables(node, active_symbols) ) + exprs = [] + if called_closures: + exprs.append(self._create_closure_check_call(called_closures, node)) + func_name = f"if_region_{self.counter}" self.counter += 1 @@ -1567,7 +1710,7 @@ class DSLPreprocessor(ast.NodeTransformer): ) assign = self.create_cf_call(func_name, yield_args, node) - return [func_def] + assign + return exprs + [func_def] + assign def generate_get_locals_or_none_call(self, write_args): return ast.Call( @@ -1721,7 +1864,6 @@ class DSLPreprocessor(ast.NodeTransformer): decorator_list=[], ) else: - else_body = [] for stmt in node.orelse: transformed_stmt = self.visit( diff --git a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py index 5d9234f2..a272497f 100644 --- a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py @@ -21,9 +21,10 @@ import pwd import time from pathlib import Path import hashlib +from functools import lru_cache from .utils.logger import log -from .jit_executor import JitExecutor +from .jit_executor import JitCompiledFunction from .._mlir import ir @@ -33,6 +34,9 @@ from .._mlir import ir def get_current_user(): + """ + Get the current user. This is used to determine the path to the cache directory. + """ # Try to get the user from the environment variable first user = os.getenv("USER") or os.getenv("USERNAME") if not user: @@ -41,6 +45,9 @@ def get_current_user(): return user +# default_generated_ir_path is the path to the cache directory. +# It is set to /tmp/{user}/cutlass_python_cache/ by default. +# If the user is not found, the default path is used or /tmp/cutlass_python_cache/ is used. try: default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/" except Exception as e: @@ -49,8 +56,25 @@ except Exception as e: print(f"Could not determine user, using default path. Error: {e}") +@lru_cache(maxsize=1) +def get_default_file_dump_root(): + """ + Get the default file dump root. + """ + dump_root = Path.cwd() + return dump_root + + def load_ir(file, asBytecode=False): - """Load generated IR from a file.""" + """Load generated IR from a file. + + :param file: The path to the file to load. + :type file: str + :param asBytecode: Whether to load the IR as bytecode, defaults to False + :type asBytecode: bool, optional + :return: The function name and the IR module + :rtype: tuple[str, ir.Module] + """ assert "mlir" in file func_name = file.split(".mlir")[0].split("dsl_")[-1] with ir.Context() as ctx: @@ -61,7 +85,16 @@ def load_ir(file, asBytecode=False): def make_unique_filename(fpath: Path, new_ext: str = None) -> Path: - """Generate a unique filename with an optional new extension.""" + """ + Generate a unique filename with an optional new extension. + + :param fpath: The path to the file to generate a unique filename for. + :type fpath: Path + :param new_ext: The new extension to add to the filename, defaults to None + :type new_ext: str, optional + :return: The unique filename + :rtype: Path + """ random_part = random.randint(0, 999999) timestamp = time.time() hash_input = f"{fpath}_{timestamp}_{random_part}".encode() @@ -74,12 +107,21 @@ def save_ir( dsl_name: str, module: object, fname: str, - isTemp: bool = False, - asBytecode: bool = False, + output_dir: str | None = None, + as_bytecode: bool = False, + bytecode_writer: callable = None, ) -> str: - """Save generated IR to a file.""" + """Save generated IR to a file. + + :param dsl_name: The name of the DSL. + :type dsl_name: str + :param module: The IR module to save. + :type module: object + :param fname: The name of the file to save. + :type fname: str + """ initial_name = f"{dsl_name.lower()}_{fname}.mlir" - save_path = Path(tempfile.gettempdir() if isTemp else os.getcwd()) + save_path = Path(output_dir if output_dir else tempfile.gettempdir()) save_fname = save_path / initial_name # Random ID to avoid any collisions rnd_id = str(uuid.uuid4()) @@ -90,9 +132,12 @@ def save_ir( os.makedirs(temp_dir, exist_ok=False) temp_fname = os.path.join(temp_dir, initial_name) - if asBytecode: + if as_bytecode: with open(temp_fname, "wb") as f: - module.operation.write_bytecode(f) + if bytecode_writer: + bytecode_writer(f) + else: + module.operation.write_bytecode(f) else: with open(temp_fname, "w") as f: print(module, file=f) @@ -105,13 +150,35 @@ def save_ir( def check_func_name(jit_cache, func_name): + """Check if the function name is in the cache. + If not, create a new JitCompiledFunction object and add it to the cache. + + :param jit_cache: The cache to check. + :type jit_cache: dict + :param func_name: The name of the function to check. + :type func_name: str + :return: The cache + :rtype: dict + """ if not func_name in jit_cache: - jit_cache[func_name] = JitExecutor(None, None, None, None, None, None) + jit_cache[func_name] = JitCompiledFunction( + None, None, None, None, None, [], False, None + ) return jit_cache def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path): - """Load cache from a directory path.""" + """Load cache from a directory path. + + :param dsl_name: The name of the DSL. + :type dsl_name: str + :param cache_limit: The limit of the cache. + :type cache_limit: int + :param path: The path to the cache directory, defaults to default_generated_ir_path + :type path: str, optional + :return: The cache + :rtype: dict + """ if not os.path.exists(path): return dict() files = os.listdir(path) @@ -136,18 +203,38 @@ def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path): def dump_cache_to_path( - dsl_name, jit_cache, cache_limit, path=default_generated_ir_path + dsl_name, + jit_cache, + cache_limit, + path=default_generated_ir_path, + bytecode_writer=None, ): + """Dump the cache to a directory path. + + :param dsl_name: The name of the DSL. + :type dsl_name: str + :param jit_cache: The cache to dump. + :type jit_cache: dict + :param cache_limit: The limit of the cache. + :type cache_limit: int + :param path: The path to the cache directory, defaults to default_generated_ir_path + :type path: str, optional + :param bytecode_writer: The bytecode writer to use, defaults to None + :type bytecode_writer: callable, optional + """ log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache)) os.makedirs(path, exist_ok=True) - original_path = os.getcwd() try: - os.chdir(path) for idx, [key, value] in enumerate(jit_cache.items()): if idx >= int(cache_limit): break - save_ir(dsl_name, value.ir_module, key, asBytecode=True) + save_ir( + dsl_name, + value.ir_module, + key, + output_dir=path, + as_bytecode=True, + bytecode_writer=bytecode_writer, + ) except Exception as e: print(f"{dsl_name} failed with caching generated IR", e) - finally: - os.chdir(original_path) diff --git a/python/CuTeDSL/cutlass/base_dsl/common.py b/python/CuTeDSL/cutlass/base_dsl/common.py index 3cf413ed..7d36e01d 100644 --- a/python/CuTeDSL/cutlass/base_dsl/common.py +++ b/python/CuTeDSL/cutlass/base_dsl/common.py @@ -10,7 +10,7 @@ # is strictly prohibited. import os -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, Optional, Union, Sequence """ This module provides a Exception classes DSL class for any Dialect. @@ -150,6 +150,9 @@ def _get_friendly_cuda_error_message(error_code, error_name): "CUDA_ERROR_NOT_INITIALIZED": ( f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n" ), + "CUDA_ERROR_INVALID_CONTEXT": ( + f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n" + ), "CUDA_ERROR_INVALID_VALUE": ( f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n" f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}" @@ -157,6 +160,10 @@ def _get_friendly_cuda_error_message(error_code, error_name): } error_suggestions = { + "CUDA_ERROR_INVALID_CONTEXT": ( + f"1. Check if CUDA context is properly initialized under your environment", + f"2. Initialize CUDA context with `cuda.cuInit(0)` or `cutlass.cuda.initialize_cuda_context()`", + ), "CUDA_ERROR_INVALID_SOURCE": ( f"1. Ensure env CUTE_DSL_ARCH matches your GPU architecture", f"2. Clear the compilation cache and regenerate the kernel", @@ -266,3 +273,44 @@ class DSLNotImplemented(DSLBaseError): # Useful for stubs in your DSL that you plan to implement in the future. pass + + +class CudaDriverDependencyError(DSLRuntimeError): + """Custom error class for CUDA driver dependency issues""" + + def __init__( + self, + message: str, + ): + # Create a detailed error message with instructions + detailed_message = f"""CUDA Driver Dependency Error + +{message} + +This error typically occurs when: +• NVIDIA GPU drivers are not installed on your system +• The installed drivers are incompatible with CUDA Toolkit 12.9 or latest version +• The libcuda.so.1 library is not accessible""" + + # Use DSLRuntimeError's structured approach + super().__init__( + detailed_message, + suggestion=[ + "Install or update NVIDIA GPU drivers:", + " • Visit: https://www.nvidia.com/Download/index.aspx", + " • Download drivers compatible with CUDA Toolkit 12.9 or latest version", + " • Follow the installation instructions for your OS", + "", + "Verify driver installation:", + " • Run: nvidia-smi", + " • This should display GPU information without errors", + "", + "Check CUDA library availability:", + " • Run: ldconfig -p | grep libcuda", + " • This should show libcuda.so.1 in the output", + "", + "For more information, see:", + " • CUDA Toolkit documentation: https://docs.nvidia.com/cuda/", + " • CUTLASS DSL requirements: nvidia-cutlass-dsl documentation", + ], + ) diff --git a/python/CuTeDSL/cutlass/base_dsl/compiler.py b/python/CuTeDSL/cutlass/base_dsl/compiler.py index f8b2da07..9aee3acc 100644 --- a/python/CuTeDSL/cutlass/base_dsl/compiler.py +++ b/python/CuTeDSL/cutlass/base_dsl/compiler.py @@ -19,9 +19,9 @@ from typing import Sequence, Optional, Tuple import os import sys import inspect -import argparse -from .common import DSLRuntimeError +from .common import DSLRuntimeError, CudaDriverDependencyError from .utils.logger import log +from .env_manager import EnvironmentVarManager _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) sys.path.append(_SCRIPT_PATH) @@ -59,6 +59,7 @@ class CompilationError(RuntimeError): self.arch = arch # Call parent with formatted error to avoid showing class name super().__init__("") # Empty string to avoid class name + # Store formatted error for str() representation self._formatted_error = self._format_error() @@ -96,10 +97,8 @@ class Compiler: def __init__(self, passmanager, execution_engine): self.passmanager = passmanager self.execution_engine = execution_engine - - def __call__(self, module): - """Convenience application method.""" - self.compile(module) + # Flag to track if CUDA dependencies have been checked once in this process + self._cuda_dependencies_checked = False def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]: """Process error message to extract NVVM error and IR context""" @@ -161,6 +160,11 @@ class Compiler: def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()): """Wraps the module in a JIT execution engine.""" + # Check CUDA driver and GPU dependencies before JIT execution (once per process) + self._check_cuda_dependencies_once(shared_libs) + + # If pre-checks passed, attempt to create ExecutionEngine + # Any failures at this point are likely non-CUDA related return self.execution_engine.ExecutionEngine( module, opt_level=opt_level, shared_libs=shared_libs ) @@ -181,108 +185,381 @@ class Compiler: cuda_toolkit, arch, ) + return self.jit(module, opt_level, shared_libs) + def _check_cuda_dependencies_once(self, shared_libs: Sequence[str]) -> None: + """ + Check CUDA dependencies only once per process lifecycle. + After the first check (success or failure), skip all subsequent checks + as the runtime environment doesn't change during process execution. + """ + if self._cuda_dependencies_checked: + return # Already checked in this process, skip + + # Mark as checked to skip all future validations + self._cuda_dependencies_checked = True + + # Simple CUDA driver check - just call cuInit(0) + try: + import cuda.bindings.driver as cuda + + cuda.cuInit(0) + except Exception as e: + # Create a comprehensive error message for CUDA driver issues + error_message = ( + "CUDA runtime initialization failed during dependency check." + ) + + raise CudaDriverDependencyError( + message=error_message, + ) + + +class CompileOption: + """ + Base class for compile options. + """ + + option_name = "" # name of the compile option in the pipeline + + def __init__(self, val): + self._value = val + + def serialize(self): + return f"{self.__class__.option_name}={self._value}" + + @property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._value = value + + +class BooleanCompileOption(CompileOption): + def __init__(self, val: bool = True): + super().__init__(val) + + def serialize(self): + return f"{self.__class__.option_name}={'true' if self._value else 'false'}" + + +class StringCompileOption(CompileOption): + def serialize(self): + if self._value: + self._value = self._value.strip("'") + return f"{self.__class__.option_name}='{self._value}'" + return "" + + +class BooleanBasedFileDumpOption(CompileOption): + def __init__(self, val: bool = True): + super().__init__(val) + self._dump_path = "" + + @property + def dump_path(self): + return self._dump_path + + @dump_path.setter + def dump_path(self, path): + self._dump_path = path + + def serialize(self): + if self._value: + assert self._dump_path, ( + f"Dump path is not set for {self.__class__.__name__}" + ) + return f"{self.__class__.option_name}='{self._dump_path}'" + return "" + + +class OptLevel(CompileOption): + option_name = "opt-level" + + def __init__(self, val: int): + if val < 0 or val > 3: + raise DSLRuntimeError(f"Invalid OPT_LEVEL: {val}, valid range is [0, 3]") + super().__init__(val) + + + +class PtxasOptions(StringCompileOption): + option_name = "ptx-options" + + +class EnableAssertions(BooleanCompileOption): + option_name = "enable-assertions" + + +class GenerateLineInfo(BooleanCompileOption): + option_name = "preserve-line-info" + + +class KeepCUBIN(BooleanBasedFileDumpOption): + option_name = "dump-cubin-path" + + +class KeepPTX(BooleanBasedFileDumpOption): + option_name = "dump-ptx-path" + + +class LinkLibraries(StringCompileOption): + option_name = "link-libraries" + + +class GPUArch(StringCompileOption): + option_name = "cubin-chip" + + def __init__(self, val): + if isinstance(val, str) and val.startswith("sm_110"): + val = val.replace("sm_110", "sm_101") + super().__init__(val) + + @property + def value(self) -> bool: + return self._value + + @value.setter + def value(self, value: bool): + if isinstance(value, str) and value.startswith("sm_110"): + value = value.replace("sm_110", "sm_101") + self._value = value + class CompileOptions: - def __init__(self, options: str = ""): - """ - This class encapsulates all compilation options relevant to function compilation. - It provides a convenient way to manage and pass compilation options, - particularly for controlling compilation settings. - By centralizing these options, it ensures consistent and flexible configuration of - compilation parameters such as optimization level, debugging control, etc. + """ + This class encapsulates compilation options to configure the JIT compilation. + It provides a convenient way to manage and pass compilation options. + By centralizing these options, it ensures consistent and flexible configuration of + compilation parameters such as optimization level, debugging control, etc. + """ - :param options: The options for the function. Will be parsed by argparse. - :type options: str - """ - if not isinstance(options, str): - raise DSLRuntimeError( - f"Invalid compilation `options`: {options}, it should be a string" + def __init__(self, options=None): + self.options = { + # Compilation control options + OptLevel: OptLevel(3), + PtxasOptions: PtxasOptions(""), + # Debugging options + EnableAssertions: EnableAssertions(False), + GenerateLineInfo: GenerateLineInfo(False), + KeepCUBIN: KeepCUBIN(False), + KeepPTX: KeepPTX(False), + GPUArch: GPUArch(""), + LinkLibraries: LinkLibraries(""), + } + + if options is not None: + self._update(options) + + def _update(self, options): + def _validate_and_update_option(option): + if type(option) not in self.options: + raise DSLRuntimeError(f"Invalid compile option: {option}") + self.options[type(option)] = option + + if isinstance(options, tuple): + for option in options: + _validate_and_update_option(option) + else: + _validate_and_update_option(options) + + def apply_envar_settings(self, envar: EnvironmentVarManager, function_name: str): + # Honor the settings from environment variables as well + if envar.keep_ptx: + self.options[KeepPTX].value = True + if envar.keep_cubin: + self.options[KeepCUBIN].value = True + if envar.enable_assertions: + self.options[EnableAssertions].value = True + if envar.lineinfo: + self.options[GenerateLineInfo].value = True + + # Update the dump path if the option is set + if self.options[KeepPTX].value: + self.options[KeepPTX].dump_path = os.path.join( + envar.dump_dir, f"{function_name}.ptx" ) - self._parser = argparse.ArgumentParser() - self._parser.add_argument("--opt-level", nargs="?", type=int, default=3) - self._parser.add_argument( - "--enable-device-assertions", action="store_true", default=False + if self.options[KeepCUBIN].value: + self.options[KeepCUBIN].dump_path = os.path.join( + envar.dump_dir, f"{function_name}.cubin" + ) + + @property + def generate_line_info(self) -> bool: + return self.options[GenerateLineInfo].value + + @property + def gpu_arch(self) -> str: + return self.options[GPUArch].value + + @property + def dump_ptx_path(self) -> str | None: + return self.options[KeepPTX].dump_path if self.options[KeepPTX].value else None + + @property + def dump_cubin_path(self) -> str | None: + return ( + self.options[KeepCUBIN].dump_path if self.options[KeepCUBIN].value else None ) - self._parser.add_argument("--link-libraries", type=str, default="") - try: - self._options = self._parser.parse_args(options.split()) - except SystemExit as e: - # catch argparse error and raise as DSLRuntimeError - raise DSLRuntimeError( - f"Invalid compile options: '{options}'. Please check the option values and format." - ) - log().info("`cute.compile` CompileOptions: options=" + options) - - def to_str(self): + def to_str(self) -> str: """ Generate a string representation of all compilation options which will be used in pipeline options. """ - option_strings = [] - for key, value in vars(self._options).items(): - hyphen_key = key.replace("_", "-") - if isinstance(value, bool): - formatted_value = "true" if value else "false" - else: - formatted_value = str(value) - option_strings.append(f"{hyphen_key}={formatted_value}") + flattend_options = "" + for option in self.options.values(): + flattend_options += option.serialize() + " " - return " ".join(option_strings) + log().info("`cute.compile` CompileOptions: options=" + flattend_options) + return flattend_options -def compile(func, *args, **kwargs): +def _parse_compile_options_from_str(options: str) -> CompileOptions: """ - This function is used to compile a `cute.jit` decorated function. - It will process the compile options and input parameters, do explicit compilation and return the jit executor. - - :param func: The function to compile. It can be a regular function, a method or a class instance. - :param args: The arguments to pass to the function. - :param kwargs: The keyword arguments to pass to the function. It can contain `options` like - `opt_level` to control the compilation flags. - - :return: The jit executor. - - :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable. + Parse the compile options from a string. """ - if func is None: - raise DSLRuntimeError("Function is not set or invalid.") - if not callable(func): - raise DSLRuntimeError("Object is not callable.") + def _get_compile_option_from_str(option_str: str): + mapping = { + "opt_level": OptLevel, + "ptxas_options": PtxasOptions, + "enable_assertions": EnableAssertions, + "link_libraries": LinkLibraries, + "generate_line_info": GenerateLineInfo, + "keep_cubin": KeepCUBIN, + "keep_ptx": KeepPTX, + "gpu_arch": GPUArch, + } + return mapping[option_str] - kwargs["compile_only"] = True - kwargs["no_cache"] = True + import argparse + import shlex - if inspect.isfunction(func): - # regular function - pass - elif inspect.ismethod(func): - # if it's a method, add the instance to the first argument - args = [func.__self__] + list(args) - func = func.__func__ - elif inspect.isclass(type(func)) and hasattr(func, "__call__"): - # If it's a class instance, get the class's __call__ method - args = [func] + list(args) - # Get the actual function from the class definition - func = func.__call__.__func__ - else: + parser = argparse.ArgumentParser() + parser.add_argument("--opt-level", nargs="?", type=int, default=3) + parser.add_argument("--enable-assertions", action="store_true", default=False) + parser.add_argument("--link-libraries", type=str, default="") + parser.add_argument("--generate-line-info", action="store_true", default=False) + parser.add_argument("--keep-cubin", action="store_true", default=False) + parser.add_argument("--keep-ptx", action="store_true", default=False) + parser.add_argument("--ptxas-options", type=str, default="") + parser.add_argument("--gpu-arch", type=str, default="") + compile_options = CompileOptions() + try: + # Use shlex to properly handle options with spaces + parsed_options = shlex.split(options) if options else [] + # Avoid parsing the ptxas-options value as a hyphen key + for i in range(1, len(parsed_options)): + if parsed_options[i - 1] in ["--ptxas-options"]: + parsed_options[i] = f"'{parsed_options[i]}'" + option_dict = vars(parser.parse_args(parsed_options)) + for option, value in option_dict.items(): + option = _get_compile_option_from_str(option) + compile_options.options[option].value = value + except SystemExit as e: + # catch argparse error and raise as DSLRuntimeError raise DSLRuntimeError( - "Invalid function type, only function, method and module are supported, but got", - func, - ) + f"Invalid compile options: '{options}'. Please check the option values and format." + ) from e - # If it's a wrapped function created by jit decorator, get the original function - if hasattr(func, "__wrapped__"): - func = func.__wrapped__ + return compile_options - if not hasattr(func, "_dsl_object"): - raise DSLRuntimeError("Function is not decorated with jit decorator.") - # process compile options, extract the options and remove them from the kwargs - options = kwargs.pop("options", "") - func._dsl_object.compile_options = CompileOptions(options) - fcn_ptr = func._dsl_object._preprocess_and_execute(func) - return func._dsl_object._func(fcn_ptr, *args, **kwargs) +class CompileCallable: + def __init__(self, options=None): + def preprocess_options(option): + if type(option) is type and issubclass( + option, (BooleanCompileOption, BooleanBasedFileDumpOption) + ): + # Automatically creates a True instance of the option + return option(True) + elif isinstance(option, tuple): + return tuple(preprocess_options(opt) for opt in option) + return option + + self._compile_options = CompileOptions(preprocess_options(options)) + + def __getitem__(self, options): + """ + Get a new CompileCallable object with the specified options. + """ + new_callable_with_options = CompileCallable(options) + return new_callable_with_options + + def __call__(self, *args, **kwargs): + return self._compile(*args, **kwargs) + + def _compile(self, func, *args, **kwargs): + """ + This function is used to compile a `cute.jit` decorated function. + It will process the compile options and input parameters, do explicit compilation and return the jit executor. + + :param func: The function to compile. It can be a regular function, a method or a class instance. + :param args: The arguments to pass to the function. + :param kwargs: The keyword arguments to pass to the function. It can contain `options` like + `opt_level` to control the compilation flags. + + :return: The jit executor. + + :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable. + """ + if func is None: + raise DSLRuntimeError("Function is not set or invalid.") + + if not callable(func): + raise DSLRuntimeError("Object is not callable.") + + kwargs["compile_only"] = True + kwargs["no_cache"] = True + + if inspect.isfunction(func): + # regular function + pass + elif inspect.ismethod(func): + # if it's a method, add the instance to the first argument + args = [func.__self__] + list(args) + func = func.__func__ + elif ( + inspect.isclass(type(func)) + and hasattr(func, "__call__") + and hasattr(func.__call__, "__func__") + ): + # If it's a class instance, get the class's __call__ method + args = [func] + list(args) + # Get the actual function from the class definition + func = func.__call__.__func__ + else: + raise DSLRuntimeError( + "Invalid function type, only function, method and module are supported, but got", + func, + ) + + # If it's a wrapped function created by jit decorator, get the original function + if hasattr(func, "__wrapped__"): + func = func.__wrapped__ + + # Lazy initialization of DSL object if has not been initialized + # Use local import to avoid circular import + from .dsl import BaseDSL + + BaseDSL._lazy_initialize_dsl(func) + + if not hasattr(func, "_dsl_object"): + raise DSLRuntimeError("Function is not decorated with jit decorator.") + + # process compile options, extract the options and remove them from the kwargs + options = kwargs.pop("options", None) + if options is not None and isinstance(options, str): + compile_options = _parse_compile_options_from_str(options) + else: + compile_options = self._compile_options + func._dsl_object.compile_options = compile_options + fcn_ptr = func._dsl_object._preprocess_and_execute(func) + + if hasattr(func, "_decorator_frame"): + kwargs["_decorator_frame"] = func._decorator_frame + return func._dsl_object._func(fcn_ptr, *args, **kwargs) diff --git a/python/CuTeDSL/cutlass/base_dsl/dsl.py b/python/CuTeDSL/cutlass/base_dsl/dsl.py index 2b17d22b..4cb3c111 100644 --- a/python/CuTeDSL/cutlass/base_dsl/dsl.py +++ b/python/CuTeDSL/cutlass/base_dsl/dsl.py @@ -16,7 +16,6 @@ It handles most of the mechanics for the DSL in an agnostic way, for example, it can handle various dialect-specific tasks. """ - # Standard library imports from dataclasses import dataclass, field import atexit @@ -24,16 +23,15 @@ import os import io import sys import errno -import ctypes import re import inspect import argparse import hashlib from functools import lru_cache, wraps -from collections import namedtuple +from collections import namedtuple, OrderedDict from abc import ABC, abstractmethod -from typing import Any, Union, Tuple, get_origin, get_args, List -from types import FunctionType, SimpleNamespace +from typing import Any, Callable, List +from types import SimpleNamespace import warnings from . import typing as t @@ -41,38 +39,29 @@ from .env_manager import EnvironmentVarManager from .compiler import CompileOptions from .ast_helpers import DSLOptimizationWarning -# ============================================================================= -# CUDA Python -# ============================================================================= - -from ..base_dsl._mlir_helpers.arith import const - # ============================================================================= # Local module imports # ============================================================================= from .cache_helpers import * -from .jit_executor import JitExecutor +from .jit_executor import JitCompiledFunction, JitFunctionArtifacts from .utils.timer import timer -from .utils.logger import setup_log, log +from .utils.logger import log from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry from .ast_preprocessor import DSLPreprocessor from .common import * -from .typing import ( - get_c_pointers, - get_mlir_types, -) +from .typing import get_c_pointers, get_mlir_types, Integer +from .arch import Arch # ============================================================================= # MLIR modules # ============================================================================= from .._mlir import ir -from .._mlir import runtime as rt from .._mlir.extras import types as T -from .._mlir.dialects import arith, math, func +from .._mlir.dialects import func # ============================================================================= # Global Variables @@ -229,7 +218,6 @@ def new_from_mlir_values(obj, values): suggestion="Consider using a list or tuple instead", ) elif is_dynamic_expression(obj): - if len(values) == 0: return obj @@ -260,6 +248,7 @@ class DSLCallable: def __init__(self, func): self.func = func + self.name = func.__name__ def __call__(self, *args, **kwargs): ret = self.__func__(*args, **kwargs) @@ -277,11 +266,12 @@ class DSLCallable: @property def __name__(self): - return self.__func__.__name__ + return self.name class BaseDSL: gpu_module = None + _env_class = EnvironmentVarManager def __init__( self, @@ -324,7 +314,7 @@ class BaseDSL: self.device_compilation_only = device_compilation_only self.num_kernels = 0 # Read environment variables - self.envar = EnvironmentVarManager(self.name) + self.envar = self._env_class(self.name) self.enable_preprocessor = preprocess # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default self.jit_cache = ( @@ -344,20 +334,10 @@ class BaseDSL: if self.envar.warnings_ignore: warnings.filterwarnings("ignore") - # Initialize logger - if self.envar.log_to_console == False and self.envar.jitTimeProfiling: - self.envar.log_to_console = True - self.envar.log_level = 20 # info level - setup_log( - self.name, - self.envar.log_to_console, - self.envar.log_to_file, - f"{self.name}.log", - self.envar.log_level, - ) - - # kernel symbols are temporary symbol string variables, their values are valid until the compilation is done. - self.kernel_symbols = [] + # kernel info contains per kernel info including symbol string and CUfunction attributes to set + # It's valid until the compilation is done. + # {symbol_string: {CUfunction_attribute: value}} + self.kernel_info = OrderedDict() # used to generate unique name for gpu.launch self.launch_inner_count = 0 # initialize default compile options @@ -369,7 +349,7 @@ class BaseDSL: log().debug(f"Logger initialized for {self.name}") # Hook excepthook - if self.envar.filterStacktrace: + if self.envar.filter_stacktrace: origin_excepthook = sys.excepthook module_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__))) @@ -427,10 +407,13 @@ class BaseDSL: """ Get the original function from the decorated function """ - while fcn_ptr.__name__ != name: + + while not hasattr(fcn_ptr, "__name__") or fcn_ptr.__name__ != name: # If the function is wrapped with functools, get from __wrapped__ if hasattr(fcn_ptr, "__wrapped__"): fcn_ptr = fcn_ptr.__wrapped__ + elif isinstance(fcn_ptr, staticmethod): + fcn_ptr = fcn_ptr.__func__ # If the function is wrapped manually, it's the first in clousure elif callable(fcn_ptr.__closure__[0].cell_contents): fcn_ptr = fcn_ptr.__closure__[0].cell_contents @@ -445,6 +428,19 @@ class BaseDSL: """ Run ast transformation and return the materialized function pointer """ + + # Lazy initialization of DSL object if has not been initialized + if not hasattr(func, "_dsl_object"): + func._dsl_object = func._dsl_cls._get_dsl() + delattr(func, "_dsl_cls") + + if not func._dsl_object.enable_preprocessor: + if hasattr(func, "_decorator_frame"): + delattr(func, "_decorator_frame") + if hasattr(func, "_transformed_ast"): + delattr(func, "_transformed_ast") + return func + if hasattr(func, "_transformed_ast"): # If the function ptr is already materialized, use the existing one func._dsl_object.frame = func._decorator_frame @@ -462,16 +458,17 @@ class BaseDSL: return DSLCallable(fcn_ptr) return func - def jit_runner(self, executor, frame, *dargs, **dkwargs): + @staticmethod + def jit_runner(cls, executor_name, frame, *dargs, **dkwargs): """ Decorator to mark a function for JIT compilation. """ log().info("jit_runner") def jit_runner_decorator(func): - func._dsl_object = self # Run preprocessor that alters AST - if self.enable_preprocessor and BaseDSL._can_preprocess(**dkwargs): + func._dsl_cls = cls + if BaseDSL._can_preprocess(**dkwargs): # For an annotated function, add some DSL attributes # When materializing the AST, we need decorator's frame func._decorator_frame = frame @@ -481,7 +478,9 @@ class BaseDSL: @wraps(func) def jit_wrapper(*args, **kwargs): func_ptr = BaseDSL._preprocess_and_execute(func) - return executor(func_ptr, *args, **kwargs) + return getattr(func._dsl_object, executor_name)( + func_ptr, *args, **kwargs + ) return jit_wrapper @@ -490,15 +489,22 @@ class BaseDSL: else: return jit_runner_decorator + @staticmethod + def _lazy_initialize_dsl(func): + """ + Lazy initialization of DSL object if has not been initialized + """ + if hasattr(func, "_dsl_cls"): + func._dsl_object = func._dsl_cls._get_dsl() + delattr(func, "_dsl_cls") + @classmethod def jit(cls, *dargs, **dkwargs): """ Decorator to mark a function for JIT compilation for Host code. """ frame = inspect.currentframe().f_back - # Instantiate the DSL Class - main_dsl = cls._get_dsl() - return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs) + return BaseDSL.jit_runner(cls, "_func", frame, *dargs, **dkwargs) @classmethod def kernel(cls, *dargs, **dkwargs): @@ -506,9 +512,7 @@ class BaseDSL: Decorator to mark a function for JIT compilation for GPU. """ frame = inspect.currentframe().f_back - # Instantiate the DSL Class - main_dsl = cls._get_dsl() - return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs) + return BaseDSL.jit_runner(cls, "_kernel_helper", frame, *dargs, **dkwargs) @abstractmethod def _kernel_helper(self, func, *args, **kwargs): @@ -518,7 +522,7 @@ class BaseDSL: pass @abstractmethod - def _build_gpu_module(self, attrs): + def _build_gpu_module(self, attrs, loc=None): """ Build the module op that contains the kernels. """ @@ -581,6 +585,7 @@ class BaseDSL: function_name = re.sub(r"\s+", " ", function_name) function_name = function_name.replace(" ", "_") function_name = function_name.replace("\n", "_") + function_name = function_name.replace("/", "_") # max fname is 256 character, leave space function_name = function_name[:180] log().info(f"Final mangled function name: {function_name}") @@ -817,7 +822,7 @@ class BaseDSL: pass else: raise DSLRuntimeError( - f"failed to generate argument #{i+1} ({arg_name}) for JIT function '{function_name}'.", + f"failed to generate argument #{i + 1} ({arg_name}) for JIT function '{function_name}'.", context={ f"Argument {arg_name}": "The DSL attempted to convert it into Dynamic Expression (aka MLIR values) but failed.", f"Call-site argument value": arg, @@ -851,9 +856,9 @@ class BaseDSL: log().debug("Execution Arguments: %s", ", ".join(map(str, exe_args))) log().debug("Types: %s", ", ".join(map(str, types))) - assert len(exe_args) == len( - types - ), "expects the same number of arguments and function parameters" + assert len(exe_args) == len(types), ( + "expects the same number of arguments and function parameters" + ) return exe_args, types, adapted_args @@ -866,13 +871,30 @@ class BaseDSL: async_deps: list = field(default_factory=list) has_cluster: bool = False min_blocks_per_mp: int = 0 + use_pdl: bool = False auto_smem: bool = False + @staticmethod + def _check_and_canonicalize_dim(dim, name): + if not isinstance(dim, (list, tuple)): + dim = [dim] + + if len(dim) > 3: + raise DSLRuntimeError( + f"Expected {name} dimension to be less than or equal to 3, but got {len(dim)}" + ) + + if any(not isinstance(e, (Integer, int)) for e in dim): + raise DSLRuntimeError( + f"Expected integer for {name} dimension, but got {type(e)}" + ) + + # Pad with 1s to 3-dim vector for grid or block dimensions + return list(dim) + [1] * (3 - len(dim)) + def __post_init__(self): - if len(self.grid) != 3: - raise DSLRuntimeError(f"Expect 3d grid!") - if len(self.block) != 3: - raise DSLRuntimeError(f"Expect 3d block!") + self.grid = self._check_and_canonicalize_dim(self.grid, "grid") + self.block = self._check_and_canonicalize_dim(self.block, "block") if self.smem is None: self.smem = 0 @@ -916,20 +938,27 @@ class BaseDSL: else: ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}") - def get_location(self): + def get_location(self, frame=None): """ Get python location information and generate MLIR location """ - - if self.frame is None: - log().debug("Frame is None") - return None + frame = self.frame if frame is None else frame + frame = inspect.currentframe().f_back if frame is None else frame + frameInfo = inspect.getframeinfo(frame) file_loc = ir.Location.file( - self.frame.f_code.co_filename, self.frame.f_lineno, 0 + frame.f_code.co_filename, + frame.f_lineno, + frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0, + ) + loc = ir.Location.name( + ( + "".join([c.strip() for c in frameInfo.code_context]) + if frameInfo.code_context + else frameInfo.function + ), + childLoc=file_loc, ) - - loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc) return loc def compile_and_jit(self, module, pipeline, shared_libs, function_name=""): @@ -945,13 +974,18 @@ class BaseDSL: sys.stderr = redirect_stderr = io.StringIO() sys.stdout = redirect_stdout = io.StringIO() + compile_gpu_arch = ( + self.envar.arch + if not self.compile_options.gpu_arch + else self.compile_options.gpu_arch + ) try: kernel = self.compiler_provider.compile_and_jit( module, pipeline, shared_libs=shared_libs, cuda_toolkit=self.envar.cuda_toolkit, - arch=self.envar.arch, + arch=compile_gpu_arch, ) finally: @@ -971,7 +1005,6 @@ class BaseDSL: pass def preprocess_pipeline(self, pipeline, arch) -> str: - if self.envar.cuda_toolkit is None: self.print_warning( "CUDA_TOOLKIT_PATH environment variable is not set. Cannot set toolkitPath." @@ -1046,14 +1079,18 @@ class BaseDSL: """ # Save IR in a file - if self.envar.keepIR: - save_ir(self.name, module, function_name) - - if self.envar.printIR: - print("\n//===--- ------ Generated IR ------ ---====\n") - module.operation.print( - enable_debug_info=self.envar.generate_source_location + if self.envar.keep_ir: + self.dump_mlir_path = save_ir( + self.name, + module, + function_name, + output_dir=self.envar.dump_dir, ) + + if self.envar.print_ir: + print("\n//===--- ------ Generated IR ------ ---====\n") + enable_debug_info = self.compile_options.generate_line_info + module.operation.print(enable_debug_info=enable_debug_info) print("\n//===--- --- End of Generated IR -- ---====\n") # Verify the module @@ -1075,22 +1112,20 @@ class BaseDSL: gpu_module_attrs, args, args_spec, + frame=None, ): - # This location is set to None for now; otherwise, calls to the same - # function on different lines would produce different line numbers, - # which would break the cache. - loc = None # self.get_location() - def build_ir_module(): - module = ir.Module.create(loc=loc) + module = ir.Module.create(loc=self.get_location(frame)) unit_attr = ir.UnitAttr.get() module.operation.attributes["gpu.container_module"] = unit_attr with ir.InsertionPoint(module.body): # Always generate gpu module. It's canonicalized by the compiler when it's not used. - self._build_gpu_module(gpu_module_attrs) + self._build_gpu_module(gpu_module_attrs, loc=self.get_location(frame)) - fop = func.FuncOp(function_name, (func_types, []), loc=loc) + fop = func.FuncOp( + function_name, (func_types, []), loc=self.get_location(frame) + ) fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() log().debug("Generated Function OP [%s]", fop) with ir.InsertionPoint(fop.add_entry_block()): @@ -1100,7 +1135,7 @@ class BaseDSL: # Call user function body try: result = funcBody(*ir_args, **ir_kwargs) - func.ReturnOp([]) + func.ReturnOp([], loc=self.get_location(frame)) except NameError as name_error: raise DSLRuntimeError( f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥", @@ -1113,7 +1148,7 @@ class BaseDSL: return module, result # Build IR module - profiler = timer(enable=self.envar.jitTimeProfiling) + profiler = timer(enable=self.envar.jit_time_profiling) module, result = profiler(build_ir_module)() module_hash = self.get_module_hash(module, function_name) @@ -1122,12 +1157,29 @@ class BaseDSL: return module, module_hash, result def compile_and_cache( - self, module, module_hash, function_name, pipeline, args_spec, no_cache + self, + module, + module_hash, + function_name, + pipeline, + args_spec, + no_cache, + func_type=JitCompiledFunction, ): - arch = self.envar.arch - pipeline = self.preprocess_pipeline(self._get_pipeline(pipeline), arch) + # If `gpu-arch` is set by compile_options, use it. Otherwise, use the arch from the environment variable. + compile_gpu_arch = ( + self.envar.arch + if not self.compile_options.gpu_arch + else self.compile_options.gpu_arch + ) + # If no gpu kernels or compile_gpu_arch is same as the arch from the environment variable, generate a JIT engine. Otherwise, only do the compilation. + gen_jit_engine = self.num_kernels == 0 or compile_gpu_arch == self.envar.arch + # Preprocess the pipeline. + pipeline = self.preprocess_pipeline( + self._get_pipeline(pipeline), compile_gpu_arch + ) shared_libs = self.get_shared_libs() - profiler = timer(enable=self.envar.jitTimeProfiling) + profiler = timer(enable=self.envar.jit_time_profiling) if ( no_cache or module_hash not in self.jit_cache @@ -1139,9 +1191,13 @@ class BaseDSL: module_hash, ) # Compile and JIT MLIR module - engine = profiler(self.compile_and_jit)( - module, pipeline, shared_libs, function_name=function_name - ) + if gen_jit_engine: + engine = profiler(self.compile_and_jit)( + module, pipeline, shared_libs, function_name=function_name + ) + else: + profiler(self.compiler_provider.compile)(module, pipeline) + engine = None else: log().info( "JIT cache hit IN-FILE function=[%s] module_hash=[%s]", @@ -1149,29 +1205,38 @@ class BaseDSL: module_hash, ) module = self.jit_cache[module_hash].ir_module - engine = self.compiler_provider.jit(module, shared_libs=shared_libs) - capi_func = profiler(engine.lookup)(function_name) - jit_executor = JitExecutor( - self, + engine = ( + self.compiler_provider.jit(module, shared_libs=shared_libs) + if gen_jit_engine + else None + ) + capi_func = profiler(engine.lookup)(function_name) if engine else None + + fn = JitCompiledFunction( + module, engine, capi_func, - module, args_spec, function_name, - jit_time_profiling=self.envar.jitTimeProfiling, + self.kernel_info, + jit_time_profiling=self.envar.jit_time_profiling, + jit_function_artifacts=JitFunctionArtifacts( + PTX=self.compile_options.dump_ptx_path, + CUBIN=self.compile_options.dump_cubin_path, + MLIR=(self.dump_mlir_path if self.envar.keep_ir else None), + ), ) - jit_executor = jit_executor.update_jit_cuda_modules(self.kernel_symbols) if not no_cache: # module stored in cache is compiled. - self.jit_cache[module_hash] = jit_executor + self.jit_cache[module_hash] = fn - return jit_executor + return fn def post_compilation_cleanup(self): """Clean up some internal state after one compilation is completed.""" - # clear the kernel symbols after the compilation is done. - self.kernel_symbols = [] + # clear the kernel info after the compilation is done. + self.kernel_info = {} self.launch_inner_count = 0 # reset num_kernels to 0 for next compilation. self.num_kernels = 0 @@ -1190,6 +1255,7 @@ class BaseDSL: no_cache, compile_only, loc=None, + frame=None, ): """Generate MLIR module and compile iself.T_provider.""" with ir.Context(), ir.Location.unknown(): @@ -1209,6 +1275,7 @@ class BaseDSL: gpu_module_attrs, args, args_spec, + frame=frame, ) # dryrun is used to only generate IR @@ -1221,7 +1288,7 @@ class BaseDSL: or self.jit_cache[module_hash].capi_func is None ): # no cache or cache miss, do ir generation/compilation/jit engine - jit_executor = self.compile_and_cache( + jit_function = self.compile_and_cache( module, module_hash, function_name, pipeline, args_spec, no_cache ) else: @@ -1231,14 +1298,16 @@ class BaseDSL: function_name, module_hash, ) - jit_executor = self.jit_cache[module_hash] + jit_function = self.jit_cache[module_hash] self.post_compilation_cleanup() + # If compile_only is set, bypass execution return the jit_executor directly if compile_only: - return jit_executor + return jit_function + # Run the compiled program - jit_executor.run_compiled_program(exe_args) + jit_function.run_compiled_program(exe_args) return result @@ -1344,6 +1413,7 @@ class BaseDSL: pipeline = kwargs.pop("pipeline", None) gpu_module_attrs = kwargs.pop("gpu_module_attrs", {}) + decorator_frame = kwargs.pop("_decorator_frame", None) # Disable cache no_cache = kwargs.pop("no_cache", False) @@ -1351,6 +1421,13 @@ class BaseDSL: # Always compile(disable cache) and return the result jit_executor compile_only = kwargs.pop("compile_only", False) + if not no_cache and ( + self.envar.keep_ptx + or self.envar.keep_cubin + ): + no_cache = True + self.print_warning("Cache is disabled as user wants to generate PTX/ASM.") + if not no_cache and compile_only: no_cache = True self.print_warning("Cache is disabled as user wants to compile only.") @@ -1367,6 +1444,9 @@ class BaseDSL: # Simple name mangling function_name = self.mangle_name(function_name, canonicalized_args, args_spec) + self.compile_options.apply_envar_settings(self.envar, function_name) + if not self.compile_options.generate_line_info: + decorator_frame = None # Generate MLIR Context and start generating IR log().debug(f"Generating MLIR for function '{function_name}'") @@ -1380,8 +1460,8 @@ class BaseDSL: pipeline, no_cache, compile_only, + frame=decorator_frame, ) - return result class _KernelGenHelper(ABC): @@ -1465,22 +1545,7 @@ class BaseDSL: This function builds IR and execute the module using cuda driver. It doesn't use mlir's cuda runtime """ - ret = None - - # Step 1. Build IR - with ir.Context(), ir.Location.unknown(): - loc = self.get_location() - module = ir.Module.create(loc=loc) - unit_attr = ir.UnitAttr.get() - module.operation.attributes["gpu.container_module"] = unit_attr - with ir.InsertionPoint(module.body): - self._build_gpu_module() - ret, kernel_name = kernel_generator() - log().debug( - f"Kernel generator returned: ret={ret}, kernel_name={kernel_name}" - ) - - module = self.build_module(module, kernel_name) + ret, kernel_name, module = self._generate_kernel_module(kernel_generator) # dryrun is used to only generate IR if self.envar.dryrun: @@ -1496,6 +1561,28 @@ class BaseDSL: return ret + def _generate_kernel_module(self, kernel_generator): + """ + Generates a module marked as GPU module which contains the kernel generated by :param kernel_generator:. + + :return: A named tuple containing the launch function and function return, the kernel name and the MLIR module. + """ + ret = None + + with ir.Context(), ir.Location.unknown(): + loc = self.get_location() + module = ir.Module.create(loc=loc) + unit_attr = ir.UnitAttr.get() + module.operation.attributes["gpu.container_module"] = unit_attr + with ir.InsertionPoint(module.body): + self._build_gpu_module({}) + ret, kernel_name = kernel_generator() + log().debug( + f"Kernel generator returned: ret={ret}, kernel_name={kernel_name}" + ) + + return ret, kernel_name, self.build_module(module, kernel_name) + def generate_kernel_operands_and_types( self, kernel_func, kernel_name, args_spec, args, kwargs ): @@ -1527,9 +1614,9 @@ class BaseDSL: log().debug("Final kernel_arg_types: %s", ", ".join(map(str, kernel_arg_types))) log().debug("Final kernel_arg_attrs: %s", ", ".join(map(str, kernel_arg_attrs))) - assert ( - len(kernel_operands) == len(kernel_arg_types) == len(kernel_arg_attrs) - ), "Size of kernel_operands, kernel_arg_types and kernel_arg_attrs must be equal" + assert len(kernel_operands) == len(kernel_arg_types) == len(kernel_arg_attrs), ( + "Size of kernel_operands, kernel_arg_types and kernel_arg_attrs must be equal" + ) return kernel_operands, kernel_arg_types, kernel_arg_attrs @@ -1600,9 +1687,9 @@ class BaseDSL: if optionalArgs else None ) - assert ( - kernelGenHelper is not None - ), "kernelGenHelper should be explicitly specified!" + assert kernelGenHelper is not None, ( + "kernelGenHelper should be explicitly specified!" + ) # check arguments sig = self._check_arg_count(*args, **kwargs) @@ -1622,6 +1709,7 @@ class BaseDSL: ) ) + loc = self.get_location() with self._enter_gpu_module(): log().debug("Generating device kernel") if self.device_compilation_only: @@ -1638,7 +1726,6 @@ class BaseDSL: ) helper = kernelGenHelper() - loc = self.get_location() fop = helper.generate_func_op( kernel_types, kernel_arg_attrs, kernel_name, loc ) @@ -1667,6 +1754,7 @@ class BaseDSL: kernelOperands=kernel_operands, requiredArgs=req_args, optionalArgs=opt_args, + loc=loc, ) KernelReturns = namedtuple( @@ -1684,3 +1772,21 @@ class BaseDSL: return decorator(dargs[0]) else: return decorator + + def get_arch_enum(self) -> Arch: + """ + Get the arch enum from the environment variable + """ + arch_option = self.compile_options.gpu_arch + return Arch.from_string(arch_option if arch_option else self.envar.arch) + + def check_arch(self, criterion: Callable[[Arch], bool]) -> None: + """ + Check the arch enum by criterion, raise DSLRuntimeError if the arch enum does not satisfy the criterion + """ + arch = self.get_arch_enum() + if not criterion(arch): + raise DSLRuntimeError( + f"invalid arch, expected one of {Arch.filter(criterion)}, but got {arch}.", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) diff --git a/python/CuTeDSL/cutlass/base_dsl/env_manager.py b/python/CuTeDSL/cutlass/base_dsl/env_manager.py index fa683477..e15cc757 100644 --- a/python/CuTeDSL/cutlass/base_dsl/env_manager.py +++ b/python/CuTeDSL/cutlass/base_dsl/env_manager.py @@ -29,6 +29,7 @@ from typing import Any from ..base_dsl.runtime.cuda import get_compute_capability_major_minor from .utils.logger import log +from .cache_helpers import get_default_file_dump_root IS_WINDOWS = sys.platform == "win32" CLIB_EXT = ".dll" if IS_WINDOWS else ".so" @@ -40,12 +41,21 @@ CLIB_EXT = ".dll" if IS_WINDOWS else ".so" @lru_cache(maxsize=None) def get_str_env_var(var_name, default_value=None): + """ + Get the string value of an environment variable. + Note that the value is cached after the first call. + """ value = os.getenv(var_name) return value if value is not None else default_value @lru_cache(maxsize=None) def get_bool_env_var(var_name, default_value=False): + """ + Get the value of a boolean environment variable. + If the value it not in False, 0, or empty string, it is considered True. + Note that the value is cached after the first call. + """ value = get_str_env_var(var_name) if value is None: return default_value @@ -54,12 +64,21 @@ def get_bool_env_var(var_name, default_value=False): @lru_cache(maxsize=None) def get_int_env_var(var_name, default_value=0): + """ + Get the value of an integer environment variable. + If the value is not a valid integer, the default value 0 is returned. + Note that the value is cached after the first call. + """ value = get_str_env_var(var_name) return int(value) if value and value.isdigit() else default_value @lru_cache(maxsize=None) def has_env_var(var_name): + """ + Check if an environment variable is set. + Note that the value is cached after the first call. + """ return os.getenv(var_name) is not None @@ -85,6 +104,8 @@ def detect_gpu_arch(prefix): suffix = "" if major >= 9: suffix = "a" + if major == 11 and minor == 0: + major, minor = 10, 1 return f"sm_{major}{minor}{suffix}" @@ -244,43 +265,14 @@ def get_prefix_dsl_libs(prefix: str): return None -class EnvironmentVarManager: - """Manages environment variables for configuration options. - - Printing options: - - [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False) - - [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False) - - [DSL_NAME]_PRINT_IR: Print generated IR (default: False) - - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True) - File options: - - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False) - - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False) - Other options: - - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1). - - [DSL_NAME]_DRYRUN: Generates IR only (default: False) - - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100") - - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False) - - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False) - - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False) - - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False) - - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) - - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000) - - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None) - - [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False) - """ - +class LogEnvironmentManager: def __init__(self, prefix="DSL"): - self.prefix = prefix # change if needed + self.prefix = prefix - # Printing options - self.print_after_preprocessor = get_bool_env_var( - f"{prefix}_PRINT_AFTER_PREPROCESSOR", False - ) - self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False) - self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True) - # File options - self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False) # Logging options + self.jit_time_profiling = get_bool_env_var( + f"{prefix}_JIT_TIME_PROFILING", False + ) self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False) self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False) if ( @@ -293,9 +285,58 @@ class EnvironmentVarManager: ) self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1) + +class EnvironmentVarManager(LogEnvironmentManager): + """Manages environment variables for configuration options. + + Printing options: + - [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False) + - [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False) + - [DSL_NAME]_PRINT_IR: Print generated IR (default: False) + - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True) + File options: + - [DSL_NAME]_DUMP_DIR: Directory to dump the generated files (default: current working directory) + - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False) + - [DSL_NAME]_KEEP_PTX: Save generated PTX in a file (default: False) + - [DSL_NAME]_KEEP_CUBIN: Save generated CUBIN in a file (default: False) + - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False) + Other options: + - [DSL_NAME]_LINEINFO: Compile with `--lineinfo` enabling developer tools such as the profiler and debugger (default: False) + - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1). + - [DSL_NAME]_DRYRUN: Generates IR only (default: False) + - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100") + - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False) + - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False) + - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False) + - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False) + - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) + - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000) + - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None) + """ + + def __init__(self, prefix="DSL"): + super().__init__(prefix) + + # Printing options + self.print_after_preprocessor = get_bool_env_var( + f"{prefix}_PRINT_AFTER_PREPROCESSOR", False + ) + self.print_ir = get_bool_env_var(f"{prefix}_PRINT_IR", False) + self.filter_stacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True) + self.lineinfo = get_bool_env_var(f"{prefix}_LINEINFO", False) + self.dump_dir = get_str_env_var( + f"{prefix}_DUMP_DIR", get_default_file_dump_root() + ) + self.keep_ptx = get_bool_env_var(f"{prefix}_KEEP_PTX", False) + self.keep_cubin = get_bool_env_var(f"{prefix}_KEEP_CUBIN", False) + + # File options + self.keep_ir = get_bool_env_var(f"{prefix}_KEEP_IR", False) # Other options self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False) self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix)) + if self.arch.startswith("sm_110"): + self.arch = self.arch.replace("sm_110", "sm_101") self.warnings_as_errors = get_bool_env_var( f"{prefix}_WARNINGS_AS_ERRORS", False ) @@ -303,18 +344,17 @@ class EnvironmentVarManager: self.enable_optimization_warnings = get_bool_env_var( f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False ) - self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False) self.disable_file_caching = get_bool_env_var( f"{prefix}_DISABLE_FILE_CACHING", False ) self.file_caching_capacity = get_int_env_var( f"{prefix}_FILE_CACHING_CAPACITY", 1000 ) - self.generate_source_location = not get_bool_env_var( - f"{prefix}_NO_SOURCE_LOCATION", False - ) # set cuda self.cuda_toolkit = get_cuda_toolkit_path() # set mlir shared libraries self.shared_libs = get_prefix_dsl_libs(prefix) + + # whether to enable assert in host and device code + self.enable_assertions = get_bool_env_var(f"{prefix}_ENABLE_ASSERTIONS", False) diff --git a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py index 83268009..bd4bf34b 100644 --- a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py +++ b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py @@ -12,12 +12,16 @@ """ This module provides jit executor related classes """ + import ctypes import inspect import io -from typing import get_origin - -import numpy as np +from typing import Union, Optional +import weakref +import threading +import collections +import os +from dataclasses import dataclass # MLIR modules imports from .._mlir import ir @@ -32,44 +36,178 @@ from .utils.logger import log from .utils.timer import timer -class CudaSingleModule: - def __init__(self, cuda_module, kernel_ptr): +class CudaModuleAndKernel: + """A loaded CUDA kernel and its metadata.""" + + def __init__(self, sym, cuda_module, kernel, attrs): + self.sym = sym self.cuda_module = cuda_module - self.kernel_ptr = kernel_ptr + self.kernel = kernel + self.attrs = attrs -class CudaModules: - def __init__(self, modules, args): - # list of CudaSingleModule - self.modules = modules - # extra kernel ptr arguments for launch - self.args = args +def get_escaped_cubin_bytes(cubin_data): + """This function escapes cubin data from mlir raw bytecode to executable binary bytes""" + + def ishex(inp): + return (0x30 <= inp < 0x3A) or (0x41 <= inp < 0x47) or (0x61 <= inp < 0x67) + + converted = bytearray() + idx = 0 + while idx < len(cubin_data): + # escape the original bytes + if cubin_data[idx] == 0x5C: + # if data of idx is b'\\' + if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]): + converted += bytearray.fromhex(cubin_data[idx + 1 : idx + 3].decode()) + idx += 3 + elif cubin_data[idx + 1] == 0x5C: + converted.append(cubin_data[idx]) + idx += 2 + else: + # no escape, directly write + converted.append(cubin_data[idx]) + idx += 1 + return bytes(converted) -class JitExecutor: - def __init__( - self, - dsl, - engine, - capi_func, - ir_module, - args_spec, - function_name, - cuda_modules: CudaModules = None, - jit_time_profiling=False, - ): - self.dsl = dsl - self.engine = engine - self.capi_func = capi_func - self.ir_module = ir_module - self.args_spec = args_spec +def walk_module_and_get_cubin_data(module, sym, callback): + """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.""" + + def walk_gpu_binary_op(op): + if op.name != "gpu.binary": + return ir.WalkResult.ADVANCE + s = io.BytesIO() + op.write_bytecode(s) + cubin_data = s.getvalue() + if sym.encode() not in cubin_data: + return ir.WalkResult.ADVANCE + + if "kernels" != op.opview.sym_name.value and sym != op.opview.sym_name.value: + return ir.WalkResult.ADVANCE + # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir + func_sym = sym + if sym == op.opview.sym_name.value and not sym.endswith("_kernel"): + func_sym = sym.rsplit("_", 1)[0] + + cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0] + cubin_data = get_escaped_cubin_bytes(cubin_data) + callback(sym, func_sym, cubin_data) + return ir.WalkResult.ADVANCE + + module.operation.walk(walk_gpu_binary_op) + + +def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel]: + """Loads all kernels from the IR module that match the given set of symbols.""" + if not kernel_info: + return [] # no modules + + # don't sort because the external kernel pointers are recorded in the order called in ir module. + kernel_symbols = tuple(kernel_info.keys()) + + # load cuda module/get function pointer from module and cache + kernel_modules = collections.OrderedDict() + for sym in kernel_symbols: + log().debug(f"Loading CUDA module for symbol: {sym}") + + def walk_callback(sym, func_sym, cubin_data): + if sym in kernel_modules: + log().debug(f"Skipping already loaded symbol: {sym}") + + cubin_module = cuda_helpers.load_library_data(cubin_data) + kernel = cuda_helpers.get_library_kernel(cubin_module, func_sym) + + # Setup attributes we want applied to the loaded kernel functions. + # A copy is made so we can update one of the attributes. + attrs = dict(kernel_info[sym]) + if cuda_helpers.get_driver_version() >= 11080: + attrs[ + cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED + ] = 1 + + kernel_modules[sym] = CudaModuleAndKernel(sym, cubin_module, kernel, attrs) + + walk_module_and_get_cubin_data(module, sym, walk_callback) + + return list(kernel_modules.values()) + + +class ExecutionArgs: + """Helper that wraps the function signature spec to filter exeuction and compile time arguments.""" + + def __init__(self, spec, function_name): self.function_name = function_name - if args_spec is not None: - self.original_args_spec = args_spec - self.args_spec = self.filter_runtime_arg_spec(args_spec) - # cuda kernels - self.cuda_modules = cuda_modules - self.jit_time_profiling = jit_time_profiling + self.args_spec = spec + if spec is not None: + self.args_spec = self.filter_runtime_arg_spec(spec) + self.original_args_spec = spec + + def generate_execution_args(self, args, kwargs): + """ + This function is the prune version of `generate_mlir_function_types` which only generates execution args + to get rid of mlir context. + """ + args_spec = self.args_spec + + # Process positional arguments with defaults + rectified_args = list(args) + if args_spec.defaults and len(args) < len(args_spec.args): + rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :]) + for k, v in kwargs.items(): + if k in args_spec.args: + idx = args_spec.args.index(k) + if idx < len(rectified_args): + rectified_args[idx] = v + else: + rectified_args.append(v) + + # Process keyword arguments + rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args} + if args_spec.kwonlydefaults and len(rectified_kwargs) < len( + args_spec.kwonlyargs + ): + rectified_kwargs.update(args_spec.kwonlydefaults) + + # args/kwargs must match arg_specs + if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len( + args_spec.kwonlyargs + ): + raise DSLRuntimeError( + "input args/kwargs length does not match runtime function signature!", + context={ + "input args length": len(rectified_args), + "input kwargs length": len(rectified_kwargs), + "function signature args length": len(args_spec.args), + "function signature kwonlyargs length": len(args_spec.kwonlyargs), + }, + ) + + exe_args = [] + adapted_args = [] + input_args = rectified_args + list(rectified_kwargs.values()) + input_arg_names = args_spec.args + args_spec.kwonlyargs + for arg, arg_name in zip(input_args, input_arg_names): + # short-cut for args already converted + if hasattr(arg, "__c_pointers__"): + exe_args.extend(arg.__c_pointers__()) + continue + + arg_type = args_spec.annotations.get(arg_name, None) + + # Implicit cast to NumericMeta + if isinstance(arg_type, t.NumericMeta): + arg = t.cast(arg, arg_type) + else: + # If not any known type, try registered adapter to do the conversion + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + if adapter: + arg = adapter(arg) + adapted_args.append(arg) + + exe_args.extend(get_c_pointers(arg)) + + return exe_args, adapted_args def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec): runtime_args = [] @@ -131,19 +269,13 @@ class JitExecutor: annotations=runtime_annotations, ) - def __del__(self): - if self.cuda_modules: - cuda_modules = [module.cuda_module for module in self.cuda_modules.modules] - for module in set(cuda_modules): - cuda_helpers.unload_cubin_module(module) - - def get_constexpr_args(self) -> list[dict[str, int | str]]: + def get_constexpr_args(self) -> list[dict[str, Union[int, str]]]: """ This function returns the constexpr args that have been pruned from the original function signature. The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). - :rtype: list[dict[str, int | str]] + :rtype: list[dict[str, Union[int, str]]] """ if self.original_args_spec is None: return list() @@ -160,198 +292,261 @@ class JitExecutor: ) return constexpr_args - def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec): - """ - This function is the prune version of `generate_mlir_function_types` which only generates execution args - to get rid of mlir context. - """ - # Process positional arguments with defaults - rectified_args = list(args) - if args_spec.defaults and len(args) < len(args_spec.args): - rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :]) - for k, v in kwargs.items(): - if k in args_spec.args: - idx = args_spec.args.index(k) - if idx < len(rectified_args): - rectified_args[idx] = v - else: - rectified_args.append(v) +class JitExecuteContext: + """Holds device specific context for execution.""" - # Process keyword arguments - rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args} - if args_spec.kwonlydefaults and len(rectified_kwargs) < len( - args_spec.kwonlyargs - ): - rectified_kwargs.update(args_spec.kwonlydefaults) + def __init__( + self, + module: "JitModule", + kernel_fns=[], + context: Optional[cuda_helpers.DevicePrimaryContext] = None, + ): + self.module = module + self.kernel_functions = kernel_fns + self.kernel_functions_ptrs = [ctypes.c_void_p(k.getPtr()) for k in kernel_fns] + self.context = context - # args/kwargs must match arg_specs - if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len( - args_spec.kwonlyargs - ): - raise DSLRuntimeError( - "input args/kwargs length does not match runtime function signature!", - context={ - "input args length": len(rectified_args), - "input kwargs length": len(rectified_kwargs), - "function signature args length": len(args_spec.args), - "function signature kwonlyargs length": len(args_spec.kwonlyargs), - }, - ) - exe_args = [] - adapted_args = [] - input_args = rectified_args + list(rectified_kwargs.values()) - input_arg_names = args_spec.args + args_spec.kwonlyargs - for arg, arg_name in zip(input_args, input_arg_names): - # short-cut for args already converted - if hasattr(arg, "__c_pointers__"): - exe_args.extend(arg.__c_pointers__()) - continue +class JitModule: + """Holds the execution engine and cuda modules.""" - arg_type = args_spec.annotations.get(arg_name, None) + def __init__( + self, + engine, + capi_func, + args_spec: ExecutionArgs, + modules: list[CudaModuleAndKernel], + ): + self.engine = engine + self.capi_func = capi_func + self.args_spec = args_spec + self.cuda_modules = modules + self._unloaded = False - # Implicit cast to NumericMeta - if isinstance(arg_type, t.NumericMeta): - arg = t.cast(arg, arg_type) - else: - # If not any known type, try registered adapter to do the conversion - adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) - if adapter: - arg = adapter(arg) - adapted_args.append(arg) + def get_device_execute_context(self, device=None) -> JitExecuteContext: + if self._unloaded: + raise RuntimeError(f"Can not get executor for unloaded module.") - exe_args.extend(get_c_pointers(arg)) + # Host only code no need to setup kernels + if not self.cuda_modules: + return JitExecuteContext(self) - return exe_args, adapted_args + # We need a device at this point so get one if not provided. + if device is None: + device = cuda_helpers.get_current_device() + elif isinstance(device, int): + device = cuda_helpers.get_device(device) - def __call__(self, *args, **kwargs): - exe_args, adapted_args = self.generate_execution_args( - args, kwargs, self.args_spec - ) + # Activate a primary context for the device: + context = cuda_helpers.DevicePrimaryContext(device) - self.run_compiled_program(exe_args) + # Get kernel functions from the kernels + kernel_fns = [] + for m in self.cuda_modules: + fn = cuda_helpers.get_function_from_kernel(m.kernel) + kernel_fns.append(fn) + + # Set attributes for the kernel function + for attr, val in m.attrs.items(): + cuda_helpers.set_kernel_attribute(fn, attr, val) + + # This instance will "own" a reference to the primary device context. + # It will release the the reference once its no longer alive or + # an explicit call to unload is made. + # + # The default module loading mode is CU_MODULE_LAZY_LOADING so + # the module will not be loaded to the device until the first call + # to execute it. # This can be modified using CUDA_MODULE_LOADING + # environment variable. + return JitExecuteContext(self, kernel_fns, context) + + def unload(self): + try: + for m in set([m.cuda_module for m in self.cuda_modules]): + cuda_helpers.unload_library(m) + self.cuda_modules.clear() + finally: + self._unloaded = True + + def __del__(self): + self.unload() + + +class JitExecutor: + """An executable function that can be called to launch a device kernel. + + JitExecutor is tired to a specific device context and should only be called + in a context on that device. + """ + + def __init__( + self, + jit_module: JitModule, + exec_context: JitExecuteContext, + jit_time_profiling: bool, + ): + # JitExecutor will keep JitCompiledFunction alive so that the underlying + # ExecutionEngine and module data is not discarded until runtime callables + # are garbage collected. + self.jit_module = jit_module + self.exec_context = exec_context + self.profiler = timer(enable=jit_time_profiling) # Assume each execution args has type `c_void_p` to reduce the overhead of `ctypes.cast`. - def get_invoke_packed_args(self, exe_args): - if self.cuda_modules: - exe_args += self.cuda_modules.args + def _get_invoke_packed_args(self, exe_args): + exe_args += self.exec_context.kernel_functions_ptrs packed_args = (ctypes.c_void_p * len(exe_args))() for argNum in range(len(exe_args)): packed_args[argNum] = exe_args[argNum] return packed_args + def generate_execution_args(self, *args, **kwargs): + return self.jit_module.args_spec.generate_execution_args(args, kwargs) + def run_compiled_program(self, exe_args): - if self.jit_time_profiling: - profiler = timer(enable=True) + try: + packed_args = self.profiler(self._get_invoke_packed_args)(exe_args) + self.profiler(self.jit_module.capi_func)(packed_args) + except Exception as e: + raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) + + def __call__(self, *args, **kwargs): + exe_args, adapted_args = self.generate_execution_args(*args, **kwargs) + self.run_compiled_program(exe_args) + + +@dataclass +class JitFunctionArtifacts: + """Holds artifacts for a JIT-compiled function.""" + + PTX: str + CUBIN: str + MLIR: str + + def __post_init__(self): + if self.PTX is not None and os.path.exists(self.PTX): try: - packed_args = profiler(self.get_invoke_packed_args)(exe_args) - profiler(self.capi_func)(packed_args) - except Exception as e: - raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) - else: + with open(self.PTX, "r") as f: + self.PTX = f.read() + except (IOError, OSError) as e: + raise DSLRuntimeError(f"Failed to read PTX file '{self.PTX}': {e}") + if self.CUBIN is not None and os.path.exists(self.CUBIN): try: - packed_args = self.get_invoke_packed_args(exe_args) - self.capi_func(packed_args) - except Exception as e: - raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) + with open(self.CUBIN, "rb") as f: + self.CUBIN = f.read() + except (IOError, OSError) as e: + raise DSLRuntimeError(f"Failed to read CUBIN file '{self.CUBIN}': {e}") + if self.MLIR is not None and os.path.exists(self.MLIR): + try: + with open(self.MLIR, "r") as f: + self.MLIR = f.read() + except (IOError, OSError) as e: + raise DSLRuntimeError(f"Failed to read MLIR file '{self.MLIR}': {e}") - def update_jit_cuda_modules(self, kernel_symbols): - # preload cuda module from compiled cubin in ir and store to jit_executor.kernels. - if len(kernel_symbols) > 0: - extra_args = [] - module = self.ir_module - cuda_kernel_cache = dict() - cuda_driver_version = cuda_helpers.get_driver_version() - for sym in kernel_symbols: - if sym not in cuda_kernel_cache: - log().debug(f"Loading CUDA module for symbol: {sym}") - # load cuda module/get function pointer from module and cache - def walk_callback(sym, func_sym, cubin_data): - cubin_module = cuda_helpers.load_cubin_module_data(cubin_data) - kernel_ptr = cuda_helpers.get_kernel_function( - cubin_module, func_sym - ) - # Enable non-portable cluster size for CUDA version 11.8 or higher. - if cuda_driver_version >= 11080: - cuda_helpers.set_kernel_attribute( - kernel_ptr, - cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, - 1, - ) - cuda_kernel_cache[sym] = CudaSingleModule( - cubin_module, kernel_ptr - ) +class JitCompiledFunction: + """Holds a compiled function.""" - self.walk_module_and_get_cubin_data(module, sym, walk_callback) - else: - log().debug(f"Symbol {sym} already in cache") - # check if kernel is empty. - if sym in cuda_kernel_cache: - extra_args.append( - ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr()) - ) - # store to the jit result if jit result is cached. - self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args) + def __init__( + self, + ir_module, + engine, + capi_func, + args_spec, + function_name, + kernel_info, + jit_time_profiling, + jit_function_artifacts, + ): + self.ir_module = ir_module + self.engine = engine + self.capi_func = capi_func + self.function_name = function_name + self.kernel_info = kernel_info + if args_spec is not None: + self.args_spec = ExecutionArgs(args_spec, self.function_name) + self.jit_time_profiling = jit_time_profiling - return self + assert ( + isinstance(jit_function_artifacts, JitFunctionArtifacts) + or jit_function_artifacts is None + ) + self.artifacts = jit_function_artifacts - def _get_escaped_cubin_bytes(self, cubin_data): - """This function escapes cubin data from mlir raw bytecode to executable binary bytes""" + # This runtime state is stored here so that we can preserve the module + # in the compiler cache. Callers can extend the lifetime of the module + # by creating and retaining the executor. + self.jit_module = None + self._executor_lock = threading.RLock() + self._default_executor = None - def ishex(inp): - return ( - inp in range(0x30, 0x3A) - or inp in range(0x61, 0x67) - or inp in range(0x41, 0x47) - ) + @property + def __ptx__(self): + """Returns the PTX code of the JIT-compiled function.""" + return self.artifacts.PTX if self.artifacts is not None else None - converted = bytearray() - idx = 0 - while idx < len(cubin_data): - # escape the original bytes - if cubin_data[idx] == 0x5C: - # if data of idx is b'\\' - if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]): - converted += bytearray.fromhex( - cubin_data[idx + 1 : idx + 3].decode() - ) - idx += 3 - elif cubin_data[idx + 1] == 0x5C: - converted.append(cubin_data[idx]) - idx += 2 - else: - # no escape, directly write - converted.append(cubin_data[idx]) - idx += 1 - return bytes(converted) + @property + def __cubin__(self): + """Returns the CUBIN data of the JIT-compiled function.""" + return self.artifacts.CUBIN if self.artifacts is not None else None - def walk_module_and_get_cubin_data(self, module, sym, callback): - """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.""" + @property + def __mlir__(self): + """Returns the MLIR code of the JIT-compiled function.""" + return self.artifacts.MLIR if self.artifacts is not None else None - def walk_gpu_binary_op(op): - if op.name != "gpu.binary": - return ir.WalkResult.ADVANCE - s = io.BytesIO() - op.write_bytecode(s) - cubin_data = s.getvalue() - if sym.encode() not in cubin_data: - return ir.WalkResult.ADVANCE + def to(self, device=None) -> JitExecutor: + """Returns an executable function bound to the given device. - if ( - "kernels" != op.opview.sym_name.value - and sym != op.opview.sym_name.value - ): - return ir.WalkResult.ADVANCE - # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir - func_sym = sym - if sym == op.opview.sym_name.value and not sym.endswith("_kernel"): - func_sym = sym.rsplit("_", 1)[0] + For multi-device execution this method can be called for each device where + the kernel will run. - cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0] - cubin_data = self._get_escaped_cubin_bytes(cubin_data) - callback(sym, func_sym, cubin_data) - return ir.WalkResult.ADVANCE + :param device: Specifies the device for the executor. If None the current device is used. + :type device: Optional[Union[int, CUdevice]] + :return: A callable executor function. + :rtype: JitExecutor + """ + with self._executor_lock: + # We need to ensure that the modules are loaded if not already + if self.jit_module is None: + cuda_modules = load_kernels_from_ir_module( + self.ir_module, self.kernel_info + ) + self.jit_module = JitModule( + self.engine, self.capi_func, self.args_spec, cuda_modules + ) - module.operation.walk(walk_gpu_binary_op) + # Create a new executor that will be tied to a device context + # n.b. host only moduels do not load device specific modules or context. + context = self.jit_module.get_device_execute_context(device) + return JitExecutor(self.jit_module, context, self.jit_time_profiling) + + def generate_execution_args(self, *args, **kwargs): + return self.args_spec.generate_execution_args(args, kwargs) + + def __call__(self, *args, **kwargs): + """Executes the jit-compiled function under the currently active CUDA context. + + Calling this method multiple devices is not allowed and will result in unexpected + CUDA errors. If you need to call the kernel on multiple devices use `to` + to return a per-device function. + """ + exe_args, adapted_args = self.generate_execution_args(*args, **kwargs) + return self.run_compiled_program(exe_args) + + def run_compiled_program(self, exe_args): + """Executes the jit-compiled function under the currently active CUDA context. + + Calling this method multiple devices is not allowed and will result in unexpected + CUDA errors. If you need to call the kernel on multiple devices use `to` + to return a per-device function. + """ + with self._executor_lock: + if self._default_executor is None: + log().debug("Creating default executor.") + # We use a weak reference here so that this instance does not keep this + # object alive as it hold a reference to self. + proxy_self = weakref.proxy(self) + self._default_executor = proxy_self.to(None) + self._default_executor.run_compiled_program(exe_args) diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py index 97ae778c..364333dd 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py @@ -13,7 +13,6 @@ This module provides CUDA Python helper functions """ - from functools import lru_cache from dataclasses import dataclass from typing import List, Optional @@ -40,6 +39,14 @@ from .jit_arg_adapters import JitArgAdapterRegistry def _cudaGetErrorEnum(error): + """ + Get the error name of a CUDA error. + :param error: The CUDA error. + :type error: cuda.CUresult or nvrtc.nvrtcResult + :raise DSLRuntimeError: If the error type is unknown. + :return: The error name. + :rtype: str + """ if isinstance(error, cuda.CUresult): err, name = cuda.cuGetErrorName(error) return name if err == cuda.CUresult.CUDA_SUCCESS else "" @@ -50,7 +57,18 @@ def _cudaGetErrorEnum(error): def _get_gpu_arch_info(major, minor): - """Get GPU architecture information and compatibility details.""" + """ + Get GPU architecture information and compatibility details. + Return [Unknown, f"sm_{major}{minor}", [f"sm_{major}{minor}"]] if the major and minor version is not in the map. + :param major: The major version of the CUDA device. + usually obtained by calling cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device) + :type major: int + :param minor: The minor version of the CUDA device. + usually obtained by calling cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device) + :type minor: int + :return: The GPU architecture information. + :rtype: tuple(str, str, list[str]) + """ gpu_arch_map = { (7, 0): ("Volta", "sm_70", ["sm_70"]), # V100 (7, 5): ("Turing", "sm_75", ["sm_75"]), # RTX 20 Series, Quadro RTX @@ -68,9 +86,13 @@ def _get_gpu_arch_info(major, minor): def get_compute_capability_major_minor(device_id: int = 0): """ - Returns the compute capability of the CUDA device as a tuple of (major, minor). - For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell. - Returns None on failure. + Get the compute capability of the CUDA device. + :param device_id: The ID of the CUDA device. + :type device_id: int + :raise DSLRuntimeError: If the CUDA operation fails. + :return: The compute capability of the CUDA device as a tuple of (major, minor). + :rtype: tuple(int, int) + Example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell. """ try: checkCudaErrors(cuda.cuInit(0)) @@ -95,7 +117,35 @@ def get_compute_capability_major_minor(device_id: int = 0): @dataclass class DeviceInfo: - """Data class to store CUDA device information.""" + """ + Data class to store CUDA device information. + + :param device_count: The number of CUDA devices. + :type device_count: int + :param current_device: The current CUDA device. + :type current_device: int + :param device_name: The name of the CUDA device. + :type device_name: str + :param major_version: The major version of the CUDA device. + :type major_version: int + :param minor_version: The minor version of the CUDA device. + :type minor_version: int + :param arch_name: The name of the CUDA architecture. + :type arch_name: str + :param sm_arch: The SM architecture of the CUDA device. + :type sm_arch: str + :param compatible_archs: The compatible SM architectures of the CUDA device. + :type compatible_archs: list[str] + :param memory_gb: The total memory of the CUDA device in GB. + :type memory_gb: float + :param target_arch: The target architecture of the CUDA device. + :type target_arch: str + :param error_message: The error message of the CUDA device. + :type error_message: str + :param initialization_failed: Whether the CUDA initialization failed. + :type initialization_failed: bool + + """ device_count: int = 0 current_device: int = 0 @@ -113,6 +163,18 @@ class DeviceInfo: def pretty_str(self) -> str: """ Convert DeviceInfo to a formatted string for display. + :return: The formatted string. + :rtype: str + Example: + On success: + CUDA devices available: (current: ) + - Architecture: () + - Compatible SM archs: + - Total Memory: GB + On failure: + 1. CUDA initialization failed + 2. Failed to get GPU info: + 3. No devices available """ info = "" @@ -144,7 +206,8 @@ class DeviceInfo: def get_device_info() -> DeviceInfo: """ Get detailed information about CUDA devices. - Returns a DeviceInfo dataclass with device information. + :return: A DeviceInfo dataclass with device information. + :rtype: DeviceInfo """ device_info = DeviceInfo() @@ -200,10 +263,7 @@ def get_device_info() -> DeviceInfo: # Get memory info try: - total_mem = cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY, - device_info.current_device, - ) + total_mem = cuda.cuDeviceTotalMem(device_info.current_device) if total_mem[0].value == 0: device_info.memory_gb = total_mem[1] / ( 1024 * 1024 * 1024 @@ -221,7 +281,13 @@ def get_device_info() -> DeviceInfo: def checkCudaErrors(result): - """Check CUDA errors and provide detailed error messages.""" + """Check CUDA errors and provide detailed error messages. + :param result: The result of the CUDA operation. + :type result: tuple(CUresult, ...) + :raise DSLCudaRuntimeError: If the CUDA operation fails. + :return: The result of the CUDA operation, excluding the first element(CUresult) of the tuple + :rtype: tuple() + """ if result[0].value: error_code = result[0].value error_name = _cudaGetErrorEnum(result[0]) @@ -241,18 +307,51 @@ def checkCudaErrors(result): # ============================================================================= +def get_current_device(): + """ + Gets the current device on the active context. + :return: The current device. + :rtype: cuda.CUdevice + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuCtxGetDevice") + dev = checkCudaErrors(cuda.cuCtxGetDevice()) + _log().info(f"{dev} <-- cuCtxGetDevice") + return dev + + +def get_device(device_id: int): + """ + Gets a device given its ordinal. + :param device_id: The ID of the device. + :type device_id: int + :return: The device. + :rtype: cuda.CUdevice + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuDeviceGet {device_id}") + dev = checkCudaErrors(cuda.cuDeviceGet(device_id)) + _log().info(f"{dev} <-- cuDeviceGet") + return dev + + @lru_cache(maxsize=1) def initialize_cuda_context(device_id: int = 0, flags: int = 0): """ Initializes the CUDA context for a specified device. + :param device_id: The ID of the device. + :type device_id: int + :param flags: The flags for the CUDA context. + :type flags: int + :return: The context. + :rtype: cuda.CUcontext + :raise DSLRuntimeError: If the CUDA operation fails. """ # Initialize CUDA Driver API _log().info(f"cuInit {flags}") checkCudaErrors(cuda.cuInit(flags)) # Retrieve handle for device - _log().info(f"cuDeviceGet {device_id}") - cuDevice = checkCudaErrors(cuda.cuDeviceGet(device_id)) - _log().info(f"{cuDevice} <-- cuDeviceGet") + cuDevice = get_device(device_id) # Create context _log().info(f"cuCtxCreate {0} {cuDevice}") if cuda.CUDA_VERSION >= 13000: @@ -267,9 +366,52 @@ def initialize_cuda_context(device_id: int = 0, flags: int = 0): return context +def device_primary_context_retain(device): + """ + Retains the primary context on the device. + :param device: The device. + :type device: cuda.CUdevice + :return: The context. + :rtype: cuda.CUcontext + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuDevicePrimaryCtxRetain {device}") + return checkCudaErrors(cuda.cuDevicePrimaryCtxRetain(device)) + + +def device_primary_context_release(device): + """ + Releases the primary context on the device. + :param device: The device. + :type device: cuda.CUdevice + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuDevicePrimaryCtxRelease {device}") + checkCudaErrors(cuda.cuDevicePrimaryCtxRelease(device)) + + +class DevicePrimaryContext: + """ + Owns a reference to a device primary context and ensures it is released once + the object is no longer alive. + """ + + def __init__(self, device): + self.device = device + self.context = device_primary_context_retain(self.device) + + def __del__(self): + device_primary_context_release(self.device) + + def load_cubin_module(cubin_file): """ Loads a CUBIN file and returns the module. + :param cubin_file: The path to the CUBIN file. + :type cubin_file: str + :return: The module. + :rtype: cuda.CUmodule + :raise DSLRuntimeError: If the CUDA operation fails. """ # Load CUBIN file as binary data _log().info(f"read cubin {cubin_file}") @@ -286,6 +428,9 @@ def load_cubin_module(cubin_file): def unload_cubin_module(module): """ Unloads a CUBIN module. + :param module: The module. + :type module: cuda.CUmodule + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info(f"cuModuleUnload {module}") checkCudaErrors(cuda.cuModuleUnload(module)) @@ -294,6 +439,11 @@ def unload_cubin_module(module): def load_cubin_module_data(cubin_data): """ Loads a CUBIN from data and returns the module. + :param cubin_data: The binary data of the CUBIN. + :type cubin_data: bytes + :return: The module. + :rtype: cuda.CUmodule + :raise DSLRuntimeError: If the CUDA operation fails. """ # Load module data _log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}") @@ -306,6 +456,13 @@ def load_cubin_module_data(cubin_data): def get_kernel_function(module, kernel_name): """ Retrieves the kernel function from the module. + :param module: The module. + :type module: cuda.CUmodule + :param kernel_name: The name of the kernel. + :type kernel_name: str + :return: The kernel function. + :rtype: cuda.CUfunction + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info(f"cuModuleGetFunction {module} {kernel_name}") kernel = checkCudaErrors( @@ -315,9 +472,109 @@ def get_kernel_function(module, kernel_name): return kernel +def load_library(cubin_file): + """ + Loads a CUBIN file and returns the library. + :param cubin_file: The path to the CUBIN file. + :type cubin_file: str + :return: The library. + :rtype: cuda.CUlibrary + :raise DSLRuntimeError: If the CUDA operation fails. + """ + # Load CUBIN file as binary data + _log().info(f"read cubin {cubin_file}") + with open(cubin_file, "rb") as f: + cubin_data = f.read() + return load_library_data(cubin_data) + + +def unload_library(library): + """ + Unloads a CUBIN library. + :param library: The library. + :type library: cuda.CUlibrary + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuLibraryUnload {library}") + checkCudaErrors(cuda.cuLibraryUnload(library)) + _log().info(f"cuLibraryUnload done {library}") + + +def load_library_data(cubin_data): + """ + Loads a CUBIN from data and returns the library. + :param cubin_data: The binary data of the CUBIN. + :type cubin_data: bytes + :return: The library. + :rtype: cuda.CUlibrary + :raise DSLRuntimeError: If the CUDA operation fails. + """ + # Load module data + _log().info(f"cuLibraryLoadData {np.char.array(cubin_data).ctypes.data}") + + library = checkCudaErrors( + cuda.cuLibraryLoadData( + np.char.array(cubin_data).ctypes.data, None, None, 0, None, None, 0 + ) + ) + return library + + +def get_library_kernel(library, kernel_name): + """ + Retrieves the kernel from the library. + :param library: The library. + :type library: cuda.CUlibrary + :param kernel_name: The name of the kernel. + :type kernel_name: str + :return: The kernel. + :rtype: cuda.CUfunction + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuLibraryGetKernel {library} {kernel_name}") + kernel = checkCudaErrors( + cuda.cuLibraryGetKernel(library, bytes(kernel_name, "utf-8")) + ) + _log().info(f"{kernel} <-- cuLibraryGetKernel") + return kernel + + +def get_function_from_kernel(kernel): + """ + Retrieves the kernel function from the kernel. + :param kernel: The kernel. + :type kernel: cuda.CUfunction + :return: The kernel function. + :rtype: cuda.CUfunction + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuKernelGetFunction {kernel}") + kernel_fn = checkCudaErrors(cuda.cuKernelGetFunction(kernel)) + _log().info(f"{kernel_fn} <-- cuKernelGetFunction") + + return kernel_fn + + def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None): """ Launches the CUDA kernel. + :param kernel: The kernel. + :type kernel: cuda.CUfunction + :param grid_dims: The grid dimensions. + :type grid_dims: tuple(int, int, int) + :param block_dims: The block dimensions. + :type block_dims: tuple(int, int, int) + :param stream: The stream. + :type stream: cuda.CUstream + :param smem_size: The shared memory size. + :type smem_size: int + :param kernel_args: The kernel arguments. + :type kernel_args: tuple + :raise DSLRuntimeError: If the CUDA operation fails. + Example: + ``` + launch_kernel(kernel, (1, 1, 1), (1, 1, 1), stream, 0, (1, 2, 3)) + ``` """ _log().info( f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}" @@ -342,6 +599,9 @@ def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args= def stream_sync(stream): """ Synchronizes the CUDA stream. + :param stream: The stream. + :type stream: cuda.CUstream + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info(f"cuStreamSynchronize {stream}") checkCudaErrors(cuda.cuStreamSynchronize(stream)) @@ -350,6 +610,11 @@ def stream_sync(stream): def stream_create(id=0): """ Creates the CUDA stream. + :param id: The ID of the stream. + :type id: int + :return: The stream. + :rtype: cuda.CUstream + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info(f"cuStreamCreate {id}") stream = checkCudaErrors(cuda.cuStreamCreate(id)) @@ -360,6 +625,9 @@ def stream_create(id=0): def stream_destroy(stream): """ Destroys the CUDA stream. + :param stream: The stream. + :type stream: cuda.CUstream + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info(f"cuStreamDestroy {stream}") checkCudaErrors(cuda.cuStreamDestroy(stream)) @@ -368,6 +636,9 @@ def stream_destroy(stream): def context_destroy(context): """ Destroys the CUDA context. + :param context: The context. + :type context: cuda.CUcontext + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info(f"cuCtxDestroy {context}") checkCudaErrors(cuda.cuCtxDestroy(context)) @@ -376,6 +647,13 @@ def context_destroy(context): def allocate(size_in_bytes: int, stream=None): """ Allocate device memory based on numpy host array size. + :param size_in_bytes: The size of the memory to allocate. + :type size_in_bytes: int + :param stream: The stream. + :type stream: cuda.CUstream + :return: The device memory. + :rtype: cuda.CUdeviceptr + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream) if stream is None: @@ -389,6 +667,11 @@ def allocate(size_in_bytes: int, stream=None): def deallocate(device_pointer, stream=None): """ Deallocate the specified device memory pointer. + :param device_pointer: The device memory pointer. + :type device_pointer: cuda.CUdeviceptr + :param stream: The stream. + :type stream: cuda.CUstream + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info( "Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream @@ -401,7 +684,17 @@ def deallocate(device_pointer, stream=None): def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None): """ - Copy data from host to device memory. + Copy data from host to device memory + if stream is None, the copy is synchronous otherwise it is asynchronous. + :param host_pointer: The host contiguous memory pointer. + :type host_pointer: cuda.CUdeviceptr + :param device_pointer: The device memory pointer. + :type device_pointer: cuda.CUdeviceptr + :param size_in_bytes: The size of the memory to copy. + :type size_in_bytes: int + :param stream: The stream. default to None. + :type stream: cuda.CUstream + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info( "Copy host-to-device host_pointer[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]", @@ -420,7 +713,17 @@ def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None): def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None): """ - Copy data from device to host memory. + Copy data from device to host memory + if stream is None, the copy is synchronous otherwise it is asynchronous. + :param host_pointer: The host contiguous memory pointer. + :type host_pointer: cuda.CUdeviceptr + :param device_pointer: The device memory pointer. + :type device_pointer: cuda.CUdeviceptr + :param size_in_bytes: The size of the memory to copy. + :type size_in_bytes: int + :param stream: The stream. default to None. + :type stream: cuda.CUstream + :raise DSLRuntimeError: If the CUDA operation fails. """ _log().info( "Copy device-host-to device_pointer=[%s] host_pointer[%s] size_in_bytes=[%s] stream=[%s]", @@ -438,21 +741,68 @@ def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None): def default_stream(): + """ + Returns the default stream. + :return: The default stream. + :rtype: cuda.CUstream + """ return cuda.CUstream(0) +@lru_cache(maxsize=1) def get_driver_version(): """ Returns the CUDA driver version. + Note: the value is cached after the first call. + :return: The CUDA driver version. + :rtype: int + Example: + version = get_driver_version() + print(f"CUDA driver version: {version}") + >>> 12050 """ return checkCudaErrors(cuda.cuDriverGetVersion()) -def set_kernel_attribute(kernel, attribute, value): +def set_kernel_attribute(kernel, attribute, value, device=None): """ Sets a CUDA kernel attribute. + If the device is not provided, the attribute is set for the current device. + and cuda.cuFuncSetAttribute is called. + Otherwise, cuda.cuKernelSetAttribute is called. + :param kernel: The kernel. + :type kernel: cuda.CUfunction + :param attribute: The attribute. + :type attribute: cuda.CUfunction_attribute + :param value: The value. + :type value: int + :param device: The device. + :type device: cuda.CUdevice + :raise DSLRuntimeError: If the CUDA operation fails. """ - return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value)) + if device is None: + _log().info(f"cuFuncSetAttribute {kernel} {attribute} {value}") + return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value)) + else: + _log().info(f"cuKernelSetAttribute {attribute} {value} {kernel} {device}") + return checkCudaErrors( + cuda.cuKernelSetAttribute(attribute, value, kernel, device) + ) + + +def get_device_attribute(attribute, device_id: int = 0): + """ + Gets a CUDA device attribute. + :param attribute: The attribute. + :type attribute: cuda.CUdevice_attribute + :param device_id: The ID of the device. + :type device_id: int + :return: The attribute value. + :rtype: int + :raise DSLRuntimeError: If the CUDA operation fails. + """ + device = checkCudaErrors(cuda.cuDeviceGet(device_id)) + return checkCudaErrors(cuda.cuDeviceGetAttribute(attribute, device)) @JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream) diff --git a/python/CuTeDSL/cutlass/base_dsl/typing.py b/python/CuTeDSL/cutlass/base_dsl/typing.py index b46cff6d..db641811 100644 --- a/python/CuTeDSL/cutlass/base_dsl/typing.py +++ b/python/CuTeDSL/cutlass/base_dsl/typing.py @@ -473,6 +473,8 @@ class IntegerMeta(NumericMeta): np_dtype = np.bool_ elif width == 128: np_dtype = None + elif width == 4: + np_dtype = None elif signed: np_dtype = getattr(np, f"int{width}") else: @@ -1127,7 +1129,10 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): if isinstance(value, bool): res_type = Boolean elif isinstance(value, int): - res_type = Int32 + # Choose Int32 if it can represent the value, Int64 otherwise + res_type = ( + Int32 if (value <= 2147483647) and (value >= -2147483648) else Int64 + ) elif isinstance(value, float): res_type = Float32 elif isinstance(value, ArithValue): @@ -1229,11 +1234,13 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): T.i32(): Int32, T.i16(): Int16, T.i8(): Int8, + T.IntegerType.get_signless(4): Int4, T.si(128): Int128, T.si64(): Int64, T.si32(): Int32, T.si16(): Int16, T.si8(): Int8, + T.IntegerType.get_signed(4): Int4, T.ui(128): Uint128, T.ui64(): Uint64, T.ui32(): Uint32, @@ -1570,6 +1577,15 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T. raise TypeError("Negation, the operator `-` is not supported for boolean type") +class Int4( + Integer, + metaclass=IntegerMeta, + width=4, + signed=True, + mlir_type=lambda: T.IntegerType.get_signless(4), +): ... + + class Int8(Integer, metaclass=IntegerMeta, width=8, signed=True, mlir_type=T.i8): ... @@ -1661,8 +1677,17 @@ class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16): def __c_pointers__(self): if not isinstance(self.value, float): raise ValueError("only float is supported") - - return Float.__c_pointers__(self) + # Convert float32 to bfloat16 representation + # First convert the value to float32 bit representation + f32_val = np.float32(self.value) + # Get the 32-bit integer representation + bits = f32_val.view(np.uint32) + # Truncate to 16 bits, keeping the high 16 bits + bf16_bits = np.uint16(bits >> 16) + # Create a short (16-bit int) with those bits + c_val = ctypes.c_short(bf16_bits) + c_pointer = ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p) + return [c_pointer] class Float8E5M2(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E5M2): ... @@ -1703,6 +1728,7 @@ _unsupported_dst_float_types = [ ALL_DTYPES = { + Int4, Int8, Int16, Int32, @@ -1732,7 +1758,7 @@ __STR_TO_DTYPE__ = {dt.__name__: dt for dt in ALL_DTYPES} def dtype(dtype_) -> Type[Numeric]: t = None - if const_expr(isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__): + if isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__: t = __STR_TO_DTYPE__[dtype_] else: raise TypeError(f"can't interpret {dtype_} as data type") @@ -1930,6 +1956,7 @@ __all__ = [ "Int64", "Int128", "Int8", + "Int4", "Uint8", "Uint16", "Uint32", diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/logger.py b/python/CuTeDSL/cutlass/base_dsl/utils/logger.py index d4e4b4ed..418015a7 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/logger.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/logger.py @@ -78,4 +78,22 @@ def setup_log( return logger +def _init_logger_with_client_name(prefix): + from ..env_manager import LogEnvironmentManager + + log_env = LogEnvironmentManager(prefix) + + if log_env.log_to_console == False and log_env.jit_time_profiling: + log_env.log_to_console = True + log_env.log_level = 20 # info level + + setup_log( + prefix, + log_env.log_to_console, + log_env.log_to_file, + f"{prefix}.log", + log_env.log_level, + ) + + logger = setup_log("generic") diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py b/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py index d2091098..12afecae 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py @@ -10,7 +10,7 @@ # is strictly prohibited. """ - This module provides stacktrace helper functions +This module provides stacktrace helper functions """ import os diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/timer.py b/python/CuTeDSL/cutlass/base_dsl/utils/timer.py index f41d3f74..d9898d1d 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/timer.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/timer.py @@ -12,6 +12,7 @@ """ This module provides a timing helper functions """ + from functools import wraps from .logger import log diff --git a/python/CuTeDSL/cutlass/cute/__init__.py b/python/CuTeDSL/cutlass/cute/__init__.py index 8702ed91..ea216b99 100644 --- a/python/CuTeDSL/cutlass/cute/__init__.py +++ b/python/CuTeDSL/cutlass/cute/__init__.py @@ -26,6 +26,7 @@ from .typing import ( XTuple, Tiler, Layout, + ComposedLayout, Pointer, Tensor, ) @@ -35,47 +36,37 @@ from .typing import * from .core import ( assume, - is_integer, - is_int_tuple, is_static, size, + static, + get_leaves, has_underscore, slice_, make_ptr, make_layout, recast_layout, - make_fragment_like, depth, rank, - flatten_to_tuple, flatten, - unflatten, - product, - product_like, shape, size_in_bytes, make_identity_layout, make_ordered_layout, + make_layout_like, make_composed_layout, make_layout_tv, make_swizzle, + make_sparse_elem, recast_ptr, - make_tensor, - make_identity_tensor, - make_fragment, - recast_tensor, get, select, front, is_major, leading_dim, - find, - find_if, coalesce, group_modes, cosize, dice, - product_each, prepend, append, prepend_ones, @@ -83,9 +74,7 @@ from .core import ( ceil_div, slice_and_offset, crd2idx, - domain_offset, - elem_less, - transform_leaf, + idx2crd, filter_zeros, filter, tile_to_shape, @@ -109,7 +98,65 @@ from .core import ( local_partition, local_tile, printf, + # Wrapper classes + Swizzle, + E, + # User defined struct + struct, + pretty_str, + make_layout_image_mask, + repeat, + repeat_as_tuple, + repeat_like, + round_up, + is_congruent, + is_weakly_congruent, + ScaledBasis, + get_divisibility, + Ratio, +) + +from .tuple import ( + transform_leaf, + find_if, + find, + flatten_to_tuple, + unflatten, + product, + product_like, + product_each, + elem_less, +) +from .tensor import ( + TensorSSA, + ReductionOp, + make_tensor, + make_identity_tensor, + make_fragment, + make_fragment_like, + make_rmem_tensor_like, + make_rmem_tensor, + recast_tensor, + domain_offset, print_tensor, + full, + full_like, + empty_like, + ones_like, + zeros_like, + where, + any_, + all_, +) +from .atom import ( + Atom, + MmaAtom, + CopyAtom, + TiledCopy, + TiledMma, + ThrMma, + ThrCopy, + make_atom, # tiled mma/tiled copy make_mma_atom, make_tiled_mma, @@ -122,48 +169,16 @@ from .core import ( make_tiled_copy_B, make_tiled_copy_C, make_tiled_copy_C_atom, - basic_copy, - basic_copy_if, - autovec_copy, - copy, + make_cotiled_copy, copy_atom_call, - gemm, - # Wrapper classes - ComposedLayout, - Swizzle, - E, - Atom, - MmaAtom, - CopyAtom, - TiledCopy, - TiledMma, - TensorSSA, - ReductionOp, - full, - full_like, - empty_like, - ones_like, - zeros_like, - where, - any_, - all_, - # User defined struct - struct, - pretty_str, - make_layout_image_mask, - repeat_like, - round_up, - is_congruent, - is_weakly_congruent, - ScaledBasis, - get_divisibility, - Ratio, ) +from .algorithm import gemm, copy, basic_copy, basic_copy_if, autovec_copy, prefetch from . import arch from . import nvgpu from . import testing from . import runtime +from . import math # Export all math ops without "math." from .math import * @@ -175,7 +190,15 @@ from .. import cutlass_dsl as _dsl jit = _dsl.CuTeDSL.jit kernel = _dsl.CuTeDSL.kernel register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter -compile = _dsl.compile +compile = _dsl.CompileCallable() +OptLevel = _dsl.OptLevel +PtxasOptions = _dsl.PtxasOptions +EnableAssertions = _dsl.EnableAssertions +GenerateLineInfo = _dsl.GenerateLineInfo +KeepCUBIN = _dsl.KeepCUBIN +KeepPTX = _dsl.KeepPTX +GPUArch = _dsl.GPUArch +LinkLibraries = _dsl.LinkLibraries # Explicitly export all symbols for documentation generation __all__ = [ @@ -186,22 +209,22 @@ __all__ = [ "ComposedLayout", "Swizzle", "E", + "ScaledBasis", "Atom", "MmaAtom", "CopyAtom", "TiledCopy", "TiledMma", + "ThrMma", + "ThrCopy", "TensorSSA", + "ReductionOp", # Basic utility functions "assume", "is_integer", "is_int_tuple", "is_static", - "size", "has_underscore", - "slice_", - "depth", - "rank", "shape", "printf", "print_tensor", @@ -211,6 +234,7 @@ __all__ = [ "recast_layout", "make_identity_layout", "make_ordered_layout", + "make_layout_like", "make_composed_layout", "make_layout_tv", "make_layout_image_mask", @@ -220,6 +244,8 @@ __all__ = [ "make_identity_tensor", "make_fragment", "make_fragment_like", + "make_rmem_tensor", + "make_rmem_tensor_like", "recast_ptr", "recast_tensor", # Tensor manipulation @@ -230,6 +256,7 @@ __all__ = [ "leading_dim", "find", "find_if", + "transform_leaf", "coalesce", "group_modes", "cosize", @@ -237,6 +264,7 @@ __all__ = [ # Tuple operations "flatten_to_tuple", "flatten", + "unflatten", "product", "product_like", "product_each", @@ -244,6 +272,7 @@ __all__ = [ "append", "prepend_ones", "append_ones", + "elem_less", # Math operations "ceil_div", "round_up", @@ -251,7 +280,6 @@ __all__ = [ "slice_and_offset", "crd2idx", "domain_offset", - "elem_less", "filter_zeros", "filter", "tile_to_shape", @@ -280,18 +308,27 @@ __all__ = [ "tiled_divide", "local_partition", "local_tile", - # MMA and Copy operations + # MMA and Copy atom operations + "make_atom", "make_mma_atom", "make_tiled_mma", "make_copy_atom", "make_tiled_copy_tv", "make_tiled_copy", + "make_tiled_copy_S", + "make_tiled_copy_D", + "make_tiled_copy_A", + "make_tiled_copy_B", + "make_tiled_copy_C", "make_tiled_copy_C_atom", + "make_cotiled_copy", + "copy_atom_call", + # Algorithm operations "basic_copy", "basic_copy_if", "autovec_copy", "copy", - "copy_atom_call", + "prefetch", "gemm", # Tensor creation "full", @@ -302,8 +339,9 @@ __all__ = [ "where", "any_", "all_", + "repeat_as_tuple", + "repeat", "repeat_like", - "ScaledBasis", # User defined struct "struct", # Modules @@ -311,6 +349,8 @@ __all__ = [ "nvgpu", "testing", "runtime", + # Math utils + *math.__all__, # Decorators and code generation "jit", "kernel", diff --git a/python/CuTeDSL/cutlass/cute/algorithm.py b/python/CuTeDSL/cutlass/cute/algorithm.py new file mode 100644 index 00000000..e93e2f74 --- /dev/null +++ b/python/CuTeDSL/cutlass/cute/algorithm.py @@ -0,0 +1,438 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import math +from typing import Optional, Dict, Any, List, Tuple + +from cutlass._mlir import ir +from cutlass.cutlass_dsl import for_generate, yield_out, if_generate, dsl_user_op +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +from .typing import Tensor, Int64, Int16, AddressSpace +from .core import ( + rank, + is_static, + size, + make_layout, + make_ptr, + max_common_layout, + logical_divide, + append_ones, + group_modes, +) +from .atom import MmaAtom, CopyAtom, make_atom + + +@dsl_user_op +def gemm( + atom: MmaAtom, + d: Tensor, + a: Tensor, + b: Tensor, + c: Tensor, + *, + loc=None, + ip=None, + **kwargs, +) -> None: + """The GEMM algorithm. + + Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. + warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field. + + All tensors must be partitioned according to the provided MMA Atom. + + For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread + election internally. Manual thread selection is not required in such cases. + + Following dispatch rules are supported: + + - Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1) + - Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N) + - Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N) + - Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N) + - Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) + + :param atom: MMA atom + :type atom: MmaAtom + :param d: Destination tensor + :type d: Tensor + :param a: First source tensor + :type a: Tensor + :param b: Second source tensor + :type b: Tensor + :param c: Third source tensor + :type c: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR, defaults to None + :type ip: Optional[InsertionPoint], optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: None + :rtype: None + """ + + a_rank = rank(a.shape) + b_rank = rank(b.shape) + c_rank = rank(c.shape) + d_rank = rank(d.shape) + + if a_rank != b_rank: + raise ValueError("`a` and `b` must have the same rank") + + if c_rank != d_rank: + raise ValueError("`c` and `d` must have the same rank") + + if a_rank == 1: + if c_rank > 2: + raise ValueError("`c` must have rank <= 2 when `a` has rank 1") + elif a_rank == 2: + if c_rank not in (2, 3): + raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2") + elif a_rank == 3: + if c_rank != 3: + raise ValueError("`c` must have rank 3 when `a` has rank 3") + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip) + + +@dsl_user_op +def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """Performs a basic element-wise copy. + + This functions **assumes** the following pre-conditions: + 1. `size(src) == size(dst)` + + When the `src` and `dst` shapes are static, the pre-conditions are actually verified and the + element-wise loop is fully unrolled. + + :param src: Source tensor + :type src: Tensor + :param dst: Destination tensor + :type dst: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + """ + + if is_static(src.shape) and is_static(dst.shape): + simt_copy_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + src.element_type.mlir_type, src.element_type.width + ) + simt_copy = make_atom(simt_copy_ty, loc=loc, ip=ip) + return _cute_ir.copy(simt_copy, src.value, dst.value, loc=loc, ip=ip) + + s = size(dst, loc=loc, ip=ip) + # Always generate an scf.for Op when one of the tensors is dynamic + for i in for_generate(0, s, loc=loc, ip=ip): + dst[i] = src[i] + yield_out() + + +@dsl_user_op +def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """Performs a basic predicated element-wise copy. + + This functions **assumes** the following pre-conditions: + 1. `size(src) == size(dst)` + 2. `size(src) == size(pred)` + + When all shapes are static, the pre-conditions are actually verified and the element-wise loop + is fully unrolled. + + """ + if src.element_type.width != dst.element_type.width: + raise NotImplementedError( + "basic_copy_if currently only supports equal source and destination " + "element type bit width" + ) + + if is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape): + return _basic_copy_if_static(pred, src, dst, loc=loc, ip=ip) + + s = size(dst, loc=loc, ip=ip) + # Always generate an scf.for Op when one of the tensors is dynamic + for i in for_generate(0, s, loc=loc, ip=ip): + if_generate(pred[i], lambda: dst.__setitem__(i, src[i]), loc=loc, ip=ip) # type: ignore + yield_out() + + +# Version of basic_copy_if when src and dst have static shapes +# - verify size(src) == size(dst) == size(prd) +# - fully unroll the loop for now +def _basic_copy_if_static( + pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None +) -> None: + assert is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape) + if size(src, loc=loc, ip=ip) != size(dst, loc=loc, ip=ip): + raise ValueError( + "basic_copy expects the size of source, destination, and predicate tensors to match" + ) + # Fully unrolled loop in the static case for now + for i in range(size(dst, loc=loc, ip=ip)): + if_generate(pred[i], lambda: dst.__setitem__(i, src[i]), loc=loc, ip=ip) # type: ignore + + +@dsl_user_op +def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """ + Auto-vectorization SIMT copy policy. + + Given a source and destination tensors that are statically shaped, this policy figures out the + largest safe vector width that the copy instruction can take and performs the copy. + """ + if src.element_type.width != dst.element_type.width: + raise NotImplementedError( + "autovec_copy currently only supports equal source and destination " + "element type bit width" + ) + + # We are going to dispatch to copy-with-atom which requires shapes to be static + if not is_static(src.shape) or not is_static(dst.shape): + raise ValueError( + "autovec_copy expects source and destination tensors to be statically shaped" + ) + + vec_layout = max_common_layout(src, dst, loc=loc, ip=ip) + num_common_elements = size(vec_layout, loc=loc, ip=ip) + + # Next we construct an upper-bound on the number bits that can be vectorized by considering + # - the maximum alignment of the layouts + # - the maximum alignment of the pointers + + upper_bound = math.gcd(src.layout.max_alignment, dst.layout.max_alignment) + upper_bound = math.gcd(upper_bound, num_common_elements) + upper_bound *= src.element_type.width + + # For our instructions, the alignment of the pointer is an upper bound to the vector width + # max_alignment, as opposed to alignment, takes into account possible address swizzling + upper_bound = math.gcd(upper_bound, src.iterator.max_alignment * 8) + upper_bound = math.gcd(upper_bound, dst.iterator.max_alignment * 8) + + # Finally, we put a cap at 128b + num_bits_per_copy = math.gcd(upper_bound, 128) + + if (num_common_elements > 1) and (num_bits_per_copy % 8 == 0): + num_common_elements = num_bits_per_copy // src.element_type.width + + # 2 step logical divides ensuring that the divides are valid at every step + vec_src = logical_divide(src, vec_layout, loc=loc, ip=ip) + vec_dst = logical_divide(dst, vec_layout, loc=loc, ip=ip) + tiled_src = logical_divide( + vec_src, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip + ) + tiled_dst = logical_divide( + vec_dst, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip + ) + + # Dispatch to copy with atom + simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + src.element_type.mlir_type, num_bits_per_copy + ) + simt_copy = make_atom(simt_type, loc=loc, ip=ip) + return _cute_ir.copy( + simt_copy, tiled_src.value, tiled_dst.value, loc=loc, ip=ip + ) + + # Failed to vectorize, use a basic copy + basic_copy(src, dst, loc=loc, ip=ip) + + +def _parse_auto_multicast_args( + kwargs: Dict[str, Any], +) -> List[Tuple[str, ir.Attribute]]: + """ + Parse multicast-related kwargs and return a list of (attr_name, attr) pairs. + + This function consumes the following key from kwargs if present: + - 'auto_multicast': dict + dict: { 'multicast_layout': str, 'use_2cta': bool } + + Returns: + List of (attr_name, ir.Attribute) pairs to be attached to the op. + Recognized attributes: + - ('multicast_layout', #cute.layout<...>) when a layout string is provided + - ('use_2cta', unit) when use_2cta is True + """ + attr_pairs: List[Tuple[str, ir.Attribute]] = [] + + # Pop known keys to avoid leaking to trait unpack + auto_multicast = kwargs.pop("auto_multicast", None) + + use_2cta: bool = False + layout_str: Optional[str] = None + + if auto_multicast is not None: + if not isinstance(auto_multicast, dict): + raise TypeError( + "auto_multicast must be a dict with keys 'multicast_layout' and optional 'use_2cta'" + ) + layout_str = auto_multicast.get("multicast_layout", None) + use_2cta = bool(auto_multicast.get("use_2cta", False)) + + if layout_str is not None: + if not isinstance(layout_str, str): + raise TypeError( + "multicast_layout must be a string representing a CuTe layout, e.g. '(4,2):(1,0)'" + ) + attr_pairs.append( + ( + "multicast_layout", + ir.Attribute.parse(f'#cute.layout<"{layout_str}">'), + ) + ) + + if use_2cta: + attr_pairs.append(("use_2cta", ir.UnitAttr.get())) + + return attr_pairs + + +@dsl_user_op +def copy( + atom: CopyAtom, + src: Tensor, + dst: Tensor, + *, + pred: Optional[Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + """Facilitates data transfer between two tensors conforming to layout profile ``(V, Rest...)``. + + :param atom: Copy atom specifying the transfer operation + :type atom: CopyAtom + :param src: Source tensor with layout profile ``(V, Rest...)`` + :type src: Tensor + :param dst: Destination tensor with layout profile ``(V, Rest...)`` + :type dst: Tensor + :param pred: Optional predication tensor for conditional transfers, defaults to None + :type pred: Optional[Tensor], optional + :param loc: Source location information, defaults to None + :type loc: Any, optional + :param ip: Insertion point, defaults to None + :type ip: Any, optional + :param kwargs: Additional copy atom specific arguments + :type kwargs: Dict[str, Any] + :raises TypeError: If source and destination element type bit widths differ + :raises ValueError: If source and destination ranks differ + :raises ValueError: If source and destination mode-1 sizes differ + :raises NotImplementedError: If ``V-mode`` rank exceeds 2 + :return: None + :rtype: None + + The ``V-mode`` represents either: + + - A singular mode directly consumable by the provided Copy Atom + - A composite mode requiring recursive decomposition, structured as ``(V, Rest...)``, + and src/dst layout like ``((V, Rest...), Rest...)`` + + The algorithm recursively processes the ``V-mode``, decomposing it until reaching the minimum granularity + compatible with the provided Copy Atom's requirements. + + Source and destination tensors must be partitioned in accordance with the Copy Atom specifications. + Post-partitioning, both tensors will exhibit a ``(V, Rest...)`` layout profile. + + **Precondition:** The size of mode 1 must be equal for both source and destination tensors: + ``size(src, mode=[1]) == size(dst, mode=[1])`` + + **Examples**: + + TMA copy operation with multicast functionality: + + .. code-block:: python + + cute.copy(tma_atom, src, dst, tma_bar_ptr=mbar_ptr, mcast_mask=mask) + + Optional predication is supported through an additional tensor parameter. For partitioned tensors with + logical profile ``((ATOM_V,ATOM_REST),REST,...)``, the predication tensor must maintain profile + compatibility with ``(ATOM_REST,REST,...)``. + + For Copy Atoms requiring single-threaded execution, thread election is managed automatically by the + copy operation. External thread selection mechanisms are not necessary. + + .. note:: + + - Certain Atoms may require additional operation-specific keyword arguments. + - Current implementation limits ``V-mode`` rank to 2 or less. Support for higher ranks is planned + for future releases. + + """ + if isinstance(src.type, _cute_ir.MemRefType) and isinstance( + dst.type, _cute_ir.MemRefType + ): + if src.element_type.width != dst.element_type.width: + raise TypeError( + "`copy` currently only supports equal source and destination " + "element type bit width" + ) + + if rank(src) != rank(dst): + raise ValueError( + "Expected source and destination tensors to have the same rank, " + f"but got {rank(src)} and {rank(dst)}" + ) + + # Canonicalize to at least rank-2 tensors + src = group_modes(append_ones(src, up_to_rank=2), 1) + dst = group_modes(append_ones(dst, up_to_rank=2), 1) + if pred is not None: + pred = group_modes(append_ones(pred, up_to_rank=2), 1) + + if is_static(src.shape[1]) and is_static(dst.shape[1]): + if size(src, mode=[1]) != size(dst, mode=[1]): + raise ValueError( + "Expected source and destination tensors to have the same size in mode-1, " + f"but got {size(src, mode=[1])} and {size(dst, mode=[1])}" + ) + + multicast_attr_pairs = _parse_auto_multicast_args(kwargs) + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + if isinstance(pred, Tensor): + pred = pred.value + + op = _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip) + + for name, attr in multicast_attr_pairs: + op.attributes[name] = attr + + return op + + +@dsl_user_op +def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None: + """ + The Prefetch algorithm. + + The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom. + Prefetch is used for loading tensors from global memory to L2. + + Prefetch accepts Copy Atom but not all are allowed. Currently, only supports TMA prefetch. + + .. code-block:: python + + cute.prefetch(tma_prefetch, src) + + For Copy Atoms that require single-threaded execution, the copy op automatically handles thread + election internally. Manual thread selection is not required in such cases. + """ + dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip) + dummy_mcast_mask = Int16(0) + value = atom._unpack( + loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr, mcast_mask=dummy_mcast_mask + ) + return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/arch/__init__.py b/python/CuTeDSL/cutlass/cute/arch/__init__.py index 01198215..7f9682e5 100644 --- a/python/CuTeDSL/cutlass/cute/arch/__init__.py +++ b/python/CuTeDSL/cutlass/cute/arch/__init__.py @@ -11,9 +11,11 @@ from .elect import * from .mbar import * +from .numeric_conversion import * from .nvvm_wrappers import * from .smem import * from .tmem import * +from .numeric_conversion import * # __all__ is required here for documentation generation __all__ = [ @@ -44,6 +46,7 @@ __all__ = [ "grid_dim", "cluster_idx", "cluster_dim", + "cluster_size", "block_in_cluster_idx", "block_in_cluster_dim", "block_idx_in_cluster", @@ -66,9 +69,12 @@ __all__ = [ "cluster_wait", "cluster_arrive", "cluster_arrive_relaxed", - "fence_proxy", "vote_ballot_sync", + "vote_any_sync", + "vote_all_sync", + "vote_uni_sync", "popc", + "fence_proxy", "fence_view_async_tmem_load", "fence_view_async_tmem_store", "warpgroup_reg_alloc", @@ -98,4 +104,15 @@ __all__ = [ "alloc_tmem", "relinquish_tmem_alloc_permit", "dealloc_tmem", + # + # numeric_conversion.py + # + "prmt", + "cvt_i8_bf16_intrinsic", + "cvt_i4_bf16_intrinsic", + "cvt_f4e2m1_f16_intrinsic", + "cvt_i8x4_to_f32x4", + "cvt_i8x2_to_f32x2", + "cvt_i8_bf16", + "cvt_f32x2_bf16x2", ] diff --git a/python/CuTeDSL/cutlass/cute/arch/elect.py b/python/CuTeDSL/cutlass/cute/arch/elect.py index ead552af..d754ceb6 100644 --- a/python/CuTeDSL/cutlass/cute/arch/elect.py +++ b/python/CuTeDSL/cutlass/cute/arch/elect.py @@ -9,6 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir @@ -16,17 +17,17 @@ from cutlass._mlir.dialects import nvvm, scf from cutlass._mlir import ir from ..typing import Int, Int32 -from ...impl_utils import check_value_in @dsl_user_op def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32: """ - Creates a warp-uniform value from the given integer input. + Provides a compiler hint indicating that the specified value is invariant across all threads in the warp, + which may enable performance optimizations. - :param value: The integer to make warp uniform. + :param value: The integer value to be marked as warp-uniform. :type value: Int - :return: The warp-uniform value equal to the input. + :return: The input value, marked as warp-uniform. :rtype: Int32 """ return Int32( @@ -68,17 +69,7 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion: # Only one thread in the warp executes the code in this context pass """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) is_thread_leader = nvvm.elect_sync(T.bool()) if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip) return IfOpRegion(if_op.then_block, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/arch/mbar.py b/python/CuTeDSL/cutlass/cute/arch/mbar.py index 80cb7b0b..dad7e380 100644 --- a/python/CuTeDSL/cutlass/cute/arch/mbar.py +++ b/python/CuTeDSL/cutlass/cute/arch/mbar.py @@ -10,14 +10,12 @@ # is strictly prohibited. from typing import Optional +from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op from cutlass._mlir.dialects import nvvm -from cutlass._mlir import ir - -from ..typing import Pointer, Int, Boolean, Int32 -from ...impl_utils import check_value_in +from ..typing import Pointer, Int, Boolean, Int32, AddressSpace #################################################################################################### # @@ -46,17 +44,7 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None: """ A fence operation that applies to the mbarrier initializations. """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) nvvm.fence_mbarrier_init(loc=loc, ip=ip) @@ -75,17 +63,7 @@ def mbarrier_arrive_and_expect_tx( the mbarrier is converted to a remote address in the peer CTA's SMEM. """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) mbar_llvm_ptr = mbar_ptr.llvm_ptr if peer_cta_rank_in_cluster is not None: @@ -125,17 +103,7 @@ def mbarrier_expect_tx( the mbarrier is converted to a remote address in the peer CTA's SMEM. """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) mbar_llvm_ptr = mbar_ptr.llvm_ptr if peer_cta_rank_in_cluster is not None: @@ -170,17 +138,7 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: :param phase: The phase to wait for (either 0 or 1) :type phase: Int """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) timeout_ns = 10000000 # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX @@ -206,17 +164,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo :return: A boolean value indicating whether the wait operation was successful :rtype: Boolean """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) return Boolean( nvvm.mbarrier_wait_parity( @@ -245,23 +193,15 @@ def mbarrier_conditional_try_wait( :return: A boolean value indicating whether the wait operation was successful :rtype: Boolean """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) return if_generate( cond, lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip), lambda: Boolean(True).ir_value(loc=loc, ip=ip), None, [Boolean], + loc=loc, + ip=ip, ) @@ -284,17 +224,7 @@ def mbarrier_arrive( """ mbar_llvm_ptr = mbar_ptr.llvm_ptr if peer_cta_rank_in_cluster is not None: - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) mbar_llvm_ptr = nvvm.mapa_shared_cluster( mbar_llvm_ptr.type, @@ -328,17 +258,7 @@ def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> N :param mbar_ptr: A pointer to the mbarrier in SMEM :type mbar_ptr: Pointer """ - arch = CuTeDSL._get_dsl().envar.arch - check_value_in( - arch, - [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ], - "arch", - ) + CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) mbar_llvm_ptr = mbar_ptr.llvm_ptr nvvm.cp_async_mbarrier_arrive_shared( diff --git a/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py b/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py new file mode 100644 index 00000000..869f2c11 --- /dev/null +++ b/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py @@ -0,0 +1,263 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + + +from cutlass.cutlass_dsl import dsl_user_op +from cutlass._mlir import ir +from cutlass._mlir.dialects import builtin, arith, llvm, vector + +from .nvvm_wrappers import ( + cvt_i8_bf16, + cvt_f32x2_bf16x2, + cvt_i8x4_to_f32x4, + cvt_i8x2_to_f32x2, + cvt_i4x8_to_bf16x8, + cvt_i4x4_to_bf16x4, + cvt_i4x2_to_bf16x2, + cvt_i4_bf16, + cvt_f4e2m1x8_to_f16x8, + cvt_f4e2m1x4_to_f16x4, + cvt_f4e2m1x2_to_f16x2, + cvt_f4e2m1_f16, +) + +from ..typing import ( + Int4, + Int8, + Int32, + Float16, + BFloat16, + Float32, +) + + +@dsl_user_op +def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): + """ + Convert a vector of int8 to a vector of bfloat16. + + :param vec_i8: The input vector of int8. + :type vec_i8: 1D vector of int8 + :param length: The length of the input vector. + :type length: int + :return: The output 1D vector of bfloat16 with the same length as the input vector. + :rtype: 1D vector of bfloat16 + """ + src_pos = 0 + vec_i8x4_type = ir.VectorType.get([4], Int8.mlir_type, loc=loc) + vec_i8x2_type = ir.VectorType.get([2], Int8.mlir_type, loc=loc) + vec_f32x2_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) + vec_dst_type = ir.VectorType.get([length], BFloat16.mlir_type, loc=loc) + vec_dst = llvm.mlir_zero(vec_dst_type, loc=loc, ip=ip) + # try to use vectorized version + if length >= 4: + num_vec4 = length // 4 + for _ in range(num_vec4): + vec_i8x4 = vector.extract_strided_slice( + vec_i8x4_type, vec_i8, [src_pos], [4], [1], loc=loc, ip=ip + ) + vec_f32x4 = cvt_i8x4_to_f32x4(vec_i8x4, loc=loc, ip=ip) + vec_f32x2_lo = vector.extract_strided_slice( + vec_f32x2_type, vec_f32x4, [0], [2], [1], loc=loc, ip=ip + ) + vec_f32x2_hi = vector.extract_strided_slice( + vec_f32x2_type, vec_f32x4, [2], [2], [1], loc=loc, ip=ip + ) + vec_bf16x2_lo = cvt_f32x2_bf16x2(vec_f32x2_lo, loc=loc, ip=ip) + vec_bf16x2_hi = cvt_f32x2_bf16x2(vec_f32x2_hi, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_bf16x2_lo, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + vec_dst = vector.insert_strided_slice( + vec_bf16x2_hi, vec_dst, [src_pos + 2], [1], loc=loc, ip=ip + ) + src_pos += 4 + length -= 4 + if length >= 2: + vec_i8x2 = vector.extract_strided_slice( + vec_i8x2_type, vec_i8, [src_pos], [2], [1], loc=loc, ip=ip + ) + vec_f32x2 = cvt_i8x2_to_f32x2(vec_i8x2, loc=loc, ip=ip) + vec_bf16x2 = cvt_f32x2_bf16x2(vec_f32x2, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_bf16x2, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + src_pos += 2 + length -= 2 + if length >= 1: + val_bf16 = cvt_i8_bf16( + vector.extractelement( + vec_i8, + position=arith.constant(Int32.mlir_type, src_pos), + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + vec_dst = vector.insertelement( + val_bf16, + vec_dst, + position=arith.constant(Int32.mlir_type, src_pos), + loc=loc, + ip=ip, + ) + return vec_dst + + +@dsl_user_op +def cvt_i4_bf16_intrinsic(vec_i4, length, *, loc=None, ip=None): + """ + Convert a vector of int4 to a vector of bfloat16. + + :param vec_i4: The input vector of int4. + :type vec_i4: 1D vector of int4 + :param length: The length of the input vector. + :type length: int + :return: The output 1D vector of bfloat16 with the same length as the input vector. + :rtype: 1D vector of bfloat16 + """ + src_pos = 0 + vec_i4x8_type = ir.VectorType.get([8], Int4.mlir_type, loc=loc) + vec_i4x4_type = ir.VectorType.get([4], Int4.mlir_type, loc=loc) + vec_i4x2_type = ir.VectorType.get([2], Int4.mlir_type, loc=loc) + vec_dst_type = ir.VectorType.get([length], BFloat16.mlir_type, loc=loc) + vec_dst = llvm.mlir_zero(vec_dst_type, loc=loc, ip=ip) + # try to use vectorized version + if length >= 8: + num_vec8 = length // 8 + for _ in range(num_vec8): + vec_i4x8 = vector.extract_strided_slice( + vec_i4x8_type, vec_i4, [src_pos], [8], [1], loc=loc, ip=ip + ) + vec_bf16x8 = cvt_i4x8_to_bf16x8(vec_i4x8, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_bf16x8, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + src_pos += 8 + length -= 8 + if length >= 4: + vec_i4x4 = vector.extract_strided_slice( + vec_i4x4_type, vec_i4, [src_pos], [4], [1], loc=loc, ip=ip + ) + vec_bf16x4 = cvt_i4x4_to_bf16x4(vec_i4x4, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_bf16x4, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + src_pos += 4 + length -= 4 + if length >= 2: + vec_i4x2 = vector.extract_strided_slice( + vec_i4x2_type, vec_i4, [src_pos], [2], [1], loc=loc, ip=ip + ) + vec_bf16x2 = cvt_i4x2_to_bf16x2(vec_i4x2, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_bf16x2, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + src_pos += 2 + length -= 2 + if length >= 1: + val_bf16 = cvt_i4_bf16( + vector.extractelement( + vec_i4, + position=arith.constant(Int32.mlir_type, src_pos), + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + vec_dst = vector.insertelement( + val_bf16, + vec_dst, + position=arith.constant(Int32.mlir_type, src_pos), + loc=loc, + ip=ip, + ) + return vec_dst + + +@dsl_user_op +def cvt_f4e2m1_f16_intrinsic(vec_f4e2m1, length, *, loc=None, ip=None): + """ + Convert a vector of float4e2m1 to a vector of float16. + + :param vec_f4e2m1: The input vector of float4e2m1. + :type vec_f4e2m1: 1D vector of float4e2m1 + :param length: The length of the input vector. + :type length: int + :return: The output 1D vector of float16 with the same length as the input vector. + :rtype: 1D vector of float16 + """ + src_pos = 0 + vec_src_i4 = builtin.unrealized_conversion_cast( + [ir.VectorType.get([length], Int4.mlir_type, loc=loc)], + [vec_f4e2m1], + loc=loc, + ip=ip, + ) + vec_i4x8_type = ir.VectorType.get([8], Int4.mlir_type, loc=loc) + vec_i4x4_type = ir.VectorType.get([4], Int4.mlir_type, loc=loc) + vec_i4x2_type = ir.VectorType.get([2], Int4.mlir_type, loc=loc) + vec_dst_type = ir.VectorType.get([length], Float16.mlir_type, loc=loc) + vec_dst = llvm.mlir_zero(vec_dst_type, loc=loc, ip=ip) + # try to use vectorized version + if length >= 8: + num_vec8 = length // 8 + for _ in range(num_vec8): + vec_f4e2m1x8 = vector.extract_strided_slice( + vec_i4x8_type, vec_src_i4, [src_pos], [8], [1], loc=loc, ip=ip + ) + vec_f16x8 = cvt_f4e2m1x8_to_f16x8(vec_f4e2m1x8, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_f16x8, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + src_pos += 8 + length -= 8 + if length >= 4: + vec_f4e2m1x4 = vector.extract_strided_slice( + vec_i4x4_type, vec_src_i4, [src_pos], [4], [1], loc=loc, ip=ip + ) + vec_f16x4 = cvt_f4e2m1x4_to_f16x4(vec_f4e2m1x4, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_f16x4, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + src_pos += 4 + length -= 4 + if length >= 2: + vec_f4e2m1x2 = vector.extract_strided_slice( + vec_i4x2_type, vec_src_i4, [src_pos], [2], [1], loc=loc, ip=ip + ) + vec_f16x2 = cvt_f4e2m1x2_to_f16x2(vec_f4e2m1x2, loc=loc, ip=ip) + vec_dst = vector.insert_strided_slice( + vec_f16x2, vec_dst, [src_pos], [1], loc=loc, ip=ip + ) + src_pos += 2 + length -= 2 + if length >= 1: + val_f16 = cvt_f4e2m1_f16( + vector.extractelement( + vec_src_i4, + position=arith.constant(Int32.mlir_type, src_pos), + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + vec_dst = vector.insertelement( + val_f16, + vec_dst, + position=arith.constant(Int32.mlir_type, src_pos), + loc=loc, + ip=ip, + ) + return vec_dst diff --git a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py index 69e3b8ac..3b61288e 100644 --- a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +++ b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -16,7 +16,7 @@ from typing_extensions import deprecated from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir import ir -from cutlass._mlir.dialects import llvm, nvvm, vector +from cutlass._mlir.dialects import arith, llvm, nvvm, vector # Forward nvvm enums from cutlass._mlir.dialects.nvvm import ( @@ -30,11 +30,13 @@ from cutlass._mlir.dialects.nvvm import ( from ..typing import ( Int, Boolean, + Int8, Int16, Uint16, Int32, Uint32, Int64, + Float16, Float32, BFloat16, Numeric, @@ -164,6 +166,14 @@ def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: ) +@dsl_user_op +def cluster_size(*, loc=None, ip=None) -> Int32: + """ + Returns the number of CTA within the cluster. + """ + return Int32(nvvm.read_ptx_sreg_cluster_nctarank(T.i32(), loc=loc, ip=ip)) + + @dsl_user_op def block_idx_in_cluster(*, loc=None, ip=None) -> Int32: """ @@ -295,12 +305,49 @@ def shuffle_sync_op( shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip) return orig_type(shlf_res) + shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx) shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up) shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down) shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly) +@dsl_user_op +def warp_reduction( + val: Numeric, op: Callable, *, threads_in_group: int = 32, loc=None, ip=None +) -> Numeric: + """warp reduction of a Numeric value(e.g.Float32) by shuffle_sync_bfly, accepts custom binary operator. + The threads_in_group is the number of threads reduction group in a warp. + E.g. 32 means the whole warp reduced in one group. 8 means the warp is divided into 4 thread groups, each group has 8 threads in reduction. + + + :param val: register value + :type val: cutlass.Numeric + :param op: binary operator + :type op: Callable + :param threads_in_group: the number of threads reduction group in a warp + :type threads_in_group: int + :return: reduced value + :rtype: cutlass.Numeric + """ + offset = threads_in_group // 2 + while offset > 0: + val = op( + val, + shuffle_sync_bfly( + val, offset=offset, mask=-1, mask_and_clamp=31, loc=loc, ip=ip + ), + ) + offset = offset // 2 + return val + + +warp_reduction_max = partial( + warp_reduction, op=lambda x, y: fmax(x, y) if isinstance(x, Float32) else max(x, y) +) +warp_reduction_sum = partial(warp_reduction, op=lambda x, y: x + y) + + @dsl_user_op def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None: """ @@ -473,8 +520,19 @@ def fence_proxy( def vote_ballot_sync( pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None ) -> Int32: - """ - Performs a ballot operation across the warp. + """Performs a ballot operation across the warp. + + It copies the predicate from each thread in mask into the corresponding bit position of + destination register d, where the bit position corresponds to the thread's lane id. + + :param pred: The predicate value for the current thread + :type pred: Boolean + :param mask: A 32-bit integer mask specifying which threads participate, defaults to all threads (0xFFFFFFFF) + :type mask: Int, optional + :return: A 32-bit integer where each bit represents a thread's predicate value + :rtype: Int32 + + See the `PTX documentation `__. """ return Int32( nvvm.vote_ballot_sync( @@ -487,6 +545,97 @@ def vote_ballot_sync( ) +@dsl_user_op +def vote_sync_op( + pred: Boolean, kind: str, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Union[Int32, Boolean]: + return_type = Boolean + return_type_str = "pred" + return return_type( + llvm.inline_asm( + T.bool(), + [ + Boolean(pred).ir_value(loc=loc, ip=ip), + Int32(mask).ir_value(loc=loc, ip=ip), + ], + f"""{{\n\t + .reg .pred ps;\n\t + .reg .pred pd;\n\t + setp.ne.b32 ps, $1, 0;\n\t + vote.sync.{kind}.{return_type_str} pd, ps, $2;\n\t + selp.b32 $0, 1, 0, pd;\n\t + }}""", + "=r,r,i", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def vote_any_sync( + pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Boolean: + """True if source predicate is True for any non-exited threads in mask. Negate the source + predicate to compute .not_all. + + :param pred: The predicate value for the current thread + :type pred: Boolean + :param mask: A 32-bit integer mask specifying which threads participate, defaults to all + threads (0xFFFFFFFF) + :type mask: Int, optional + :return: A boolean value indicating if the source predicate is True for all non-exited + threads in mask + :rtype: Boolean + + See the `PTX documentation `__. + """ + return vote_sync_op(pred, "any", mask, loc=loc, ip=ip) + + +@dsl_user_op +def vote_all_sync( + pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Boolean: + """True if source predicate is True for all non-exited threads in mask. Negate the source + predicate to compute .none. + + :param pred: The predicate value for the current thread + :type pred: Boolean + :param mask: A 32-bit integer mask specifying which threads participate, defaults to all + threads (0xFFFFFFFF) + :type mask: Int, optional + :return: A boolean value indicating if the source predicate is True for all non-exited + threads in mask + :rtype: Boolean + + See the `PTX documentation `__. + """ + return vote_sync_op(pred, "all", mask, loc=loc, ip=ip) + + +@dsl_user_op +def vote_uni_sync( + pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Boolean: + """True f source predicate has the same value in all non-exited threads in mask. Negating + the source predicate also computes .uni + + :param pred: The predicate value for the current thread + :type pred: Boolean + :param mask: A 32-bit integer mask specifying which threads participate, defaults to all + threads (0xFFFFFFFF) + :type mask: Int, optional + :return: A boolean value indicating if the source predicate is True for all non-exited + threads in mask + :rtype: Boolean + """ + return vote_sync_op(pred, "uni", mask, loc=loc, ip=ip) + + @dsl_user_op def popc(value: Numeric, *, loc=None, ip=None) -> Numeric: """ @@ -494,7 +643,7 @@ def popc(value: Numeric, *, loc=None, ip=None) -> Numeric: """ if not isinstance(value, Numeric): value = as_numeric(value) - return type(value)(llvm.intr_ctpop(value.ir_value(), loc=loc, ip=ip)) + return type(value)(llvm.intr_ctpop(value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) @dsl_user_op @@ -546,6 +695,27 @@ fence_view_async_tmem_store = partial( ) +@dsl_user_op +def fence_view_async_shared( + *, + loc=None, + ip=None, +) -> None: + """ + Perform a fence operation on the async shared memory load or store. + + .. note:: + This function is only available on sm_90 or higher. + The fence is required to synchronize the shared memory load/store + and let the pipeline release or commit the buffer. + + This function is usually used for async execution unit (like TMA, UMMA) after the load/store operations. + """ + nvvm.fence_proxy( + nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta, loc=loc, ip=ip + ) + + @dsl_user_op def warpgroup_reg_realloc_op( reg_count: int, @@ -569,7 +739,7 @@ warpgroup_reg_dealloc = partial( def calc_packed_f32x2_op( src_a: Tuple[Float32, Float32], src_b: Tuple[Float32, Float32], - src_c: Tuple[Float32, Float32] | None, + src_c: Optional[Tuple[Float32, Float32]], calc_func: Callable, *, rnd=RoundingModeKind.RZ, @@ -579,14 +749,23 @@ def calc_packed_f32x2_op( ) -> Tuple[Float32, Float32]: vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) vec_src_a = vector.from_elements( - vec_type, tuple(as_numeric(a).ir_value() for a in src_a), loc=loc, ip=ip + vec_type, + tuple(as_numeric(a).ir_value(loc=loc, ip=ip) for a in src_a), + loc=loc, + ip=ip, ) vec_src_b = vector.from_elements( - vec_type, tuple(as_numeric(b).ir_value() for b in src_b), loc=loc, ip=ip + vec_type, + tuple(as_numeric(b).ir_value(loc=loc, ip=ip) for b in src_b), + loc=loc, + ip=ip, ) if src_c is not None: vec_src_c = vector.from_elements( - vec_type, tuple(as_numeric(c).ir_value() for c in src_c), loc=loc, ip=ip + vec_type, + tuple(as_numeric(c).ir_value(loc=loc, ip=ip) for c in src_c), + loc=loc, + ip=ip, ) vec_res = calc_func( vec_type, vec_src_a, vec_src_b, vec_src_c, rnd=rnd, ftz=ftz, loc=loc, ip=ip @@ -660,6 +839,418 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: ) +# Convert 1 int8 value to 1 bfloat16 value +@dsl_user_op +def cvt_i8_bf16(src_i8, *, loc=None, ip=None): + src_i16 = llvm.zext(Int16.mlir_type, src_i8, loc=loc, ip=ip) + val_i16 = llvm.inline_asm( + Uint16.mlir_type, + [ + src_i16, + ], + """{\n\t + .reg .b16 r;\n\t + .reg .b8 s;\n\t + mov.b16 {s,_}, $1;\n\t + cvt.rn.bf16.s8 r, s;\n\t + mov.b16 $0, r;\n\t + }""", + "=h,h", + ) + val_bf16 = llvm.bitcast(BFloat16.mlir_type, val_i16, loc=loc, ip=ip) + return val_bf16 + + +# Convert vector of 2 float values to vector of 2 bfloat16 values with satfinite rounding +@dsl_user_op +def cvt_f32x2_bf16x2(src_vec2, *, loc=None, ip=None): + src0 = vector.extractelement( + src_vec2, position=arith.constant(Int32.mlir_type, 0, loc=loc, ip=ip) + ) + src1 = vector.extractelement( + src_vec2, position=arith.constant(Int32.mlir_type, 1, loc=loc, ip=ip) + ) + rst = llvm.inline_asm( + T.i32(), + [ + Float32(src1).ir_value(loc=loc, ip=ip), + Float32(src0).ir_value(loc=loc, ip=ip), + ], + "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", + "=r,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + vec_type = ir.VectorType.get([2], BFloat16.mlir_type, loc=loc) + vec_bf16x2 = llvm.bitcast(vec_type, rst, loc=loc, ip=ip) + return vec_bf16x2 + + +# Convert 1 float32 value to 1 bfloat16 value +@dsl_user_op +def cvt_f32_bf16(src_f32, *, loc=None, ip=None): + bf16_val = llvm.inline_asm( + BFloat16.mlir_type, + [ + src_f32, + ], + "cvt.rn.bf16.f32 $0, $1;", + "=h,f", + ) + return bf16_val + + +# Convert vector of 4 int8 values to vector of 4 float32 values +@dsl_user_op +def cvt_i8x4_to_f32x4(src_vec4, *, loc=None, ip=None): + zero = arith.constant(Int32.mlir_type, 0, loc=loc, ip=ip) + mask4 = ( + arith.constant(Int32.mlir_type, 0x00000001, loc=loc, ip=ip), + arith.constant(Int32.mlir_type, 0x00000100, loc=loc, ip=ip), + arith.constant(Int32.mlir_type, 0x00010000, loc=loc, ip=ip), + arith.constant(Int32.mlir_type, 0x01000000, loc=loc, ip=ip), + ) + src_i32 = llvm.bitcast(Int32.mlir_type, src_vec4, loc=loc, ip=ip) + rst0 = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32, + mask4[0], + zero, + ], + "dp4a.s32.s32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + rst1 = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32, + mask4[1], + zero, + ], + "dp4a.s32.s32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + rst2 = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32, + mask4[2], + zero, + ], + "dp4a.s32.s32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + rst3 = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32, + mask4[3], + zero, + ], + "dp4a.s32.s32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + res0 = llvm.inline_asm( + Float32.mlir_type, + [ + rst0, + ], + "cvt.rn.f32.s32 $0, $1;", + "=f,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + res1 = llvm.inline_asm( + Float32.mlir_type, + [ + rst1, + ], + "cvt.rn.f32.s32 $0, $1;", + "=f,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + res2 = llvm.inline_asm( + Float32.mlir_type, + [ + rst2, + ], + "cvt.rn.f32.s32 $0, $1;", + "=f,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + res3 = llvm.inline_asm( + Float32.mlir_type, + [ + rst3, + ], + "cvt.rn.f32.s32 $0, $1;", + "=f,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + vec_f32x4_type = ir.VectorType.get([4], Float32.mlir_type, loc=loc) + vec_f32x4 = vector.from_elements( + vec_f32x4_type, [res0, res1, res2, res3], loc=loc, ip=ip + ) + return vec_f32x4 + + +# Convert vector of 2 int8 values to vector of 2 float32 values +@dsl_user_op +def cvt_i8x2_to_f32x2(src_vec2, *, loc=None, ip=None): + zero = arith.constant(Int32.mlir_type, 0, loc=loc, ip=ip) + mask2 = ( + arith.constant(Int32.mlir_type, 0x00000001, loc=loc, ip=ip), + arith.constant(Int32.mlir_type, 0x00000100, loc=loc, ip=ip), + ) + src_i16 = llvm.bitcast(Int16.mlir_type, src_vec2, loc=loc, ip=ip) + src_i32_pad16b = llvm.zext(Int32.mlir_type, src_i16, loc=loc, ip=ip) + rst0 = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32_pad16b, + mask2[0], + zero, + ], + "dp4a.s32.s32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + rst1 = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32_pad16b, + mask2[1], + zero, + ], + "dp4a.s32.s32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + res0 = llvm.inline_asm( + Float32.mlir_type, + [ + rst0, + ], + "cvt.rn.f32.s32 $0, $1;", + "=f,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + res1 = llvm.inline_asm( + Float32.mlir_type, + [ + rst1, + ], + "cvt.rn.f32.s32 $0, $1;", + "=f,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + vec_f32x2_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) + vec_f32x2 = vector.from_elements(vec_f32x2_type, [res0, res1], loc=loc, ip=ip) + return vec_f32x2 + + +# Permute bytes from register pair. +@dsl_user_op +def prmt(src, src_reg_shifted, prmt_indices, *, loc=None, ip=None): + return llvm.inline_asm( + T.i32(), + [ + Int32(src).ir_value(loc=loc, ip=ip), + Int32(src_reg_shifted).ir_value(loc=loc, ip=ip), + Int32(prmt_indices).ir_value(loc=loc, ip=ip), + ], + "prmt.b32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +# Convert 1 int4 value to 1 bfloat16 value +@dsl_user_op +def cvt_i4_bf16(src_i4, *, loc=None, ip=None): + # i4 -> i32 -> f32 -> bf + src_i32 = llvm.zext(Int32.mlir_type, src_i4, loc=loc, ip=ip) + src_f32 = llvm.sitofp(Float32.mlir_type, src_i32, loc=loc, ip=ip) + bf16_val = cvt_f32_bf16(src_f32, loc=loc, ip=ip) + return bf16_val + + +# Convert multiple int4 values to bfloat16 values. +# The number of elements to be converted must be be even as specified by num_elts. +# Int4 values are packed into int32 values with upper bits filled with 0 if there are less than 4 int4 values. +# Results bfloat16 values are also packed into int32 values. +@dsl_user_op +def cvt_i4_to_bf16_impl(src_i32, num_elts, *, loc=None, ip=None): + c4 = arith.constant(Int32.mlir_type, 4, loc=loc, ip=ip) + src_shr4 = llvm.lshr(src_i32, c4, loc=loc, ip=ip) + xor_mask0 = arith.constant(Int32.mlir_type, 0x08080808, loc=loc, ip=ip) + and_mask = arith.constant(Int32.mlir_type, 0x0F0F0F0F, loc=loc, ip=ip) + imm_lut = arith.constant(Int32.mlir_type, 0x0000006A, loc=loc, ip=ip) + src_i32 = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32, + and_mask, + xor_mask0, + imm_lut, + ], + "lop3.b32 $0, $1, $2, $3, $4;", + "=r,r,n,n,n", + ) + xor_mask1 = arith.constant(Int32.mlir_type, 0x88080808, loc=loc, ip=ip) + src_shr4 = llvm.inline_asm( + Int32.mlir_type, + [ + src_shr4, + and_mask, + xor_mask1, + imm_lut, + ], + "lop3.b32 $0, $1, $2, $3, $4;", + "=r,r,n,n,n", + ) + prmt_indices = [ + arith.constant(Int32.mlir_type, imme, loc=loc, ip=ip) + for imme in [ + 0x0000F4F0, + 0x0000F5F1, + 0x0000F6F2, + 0x0000F7F3, + ] + ] + num_i32_elts = num_elts // 2 + rsts = [] + for i in range(num_i32_elts): + rst = llvm.inline_asm( + Int32.mlir_type, + [ + src_i32, + src_shr4, + prmt_indices[i], + ], + "prmt.b32 $0, $1, $2, $3;", + "=r,r,r,r", + ) + rsts.append(rst) + mask_clear_top_bit = arith.constant(Int32.mlir_type, 0xFF7FFFFF, loc=loc, ip=ip) + rsts[-1] = llvm.inline_asm( + Int32.mlir_type, + [ + rsts[-1], + mask_clear_top_bit, + ], + "and.b32 $0, $1, $2;", + "=r,r,r", + ) + mul = arith.constant(Int32.mlir_type, 0x83808380, loc=loc, ip=ip) + bias = arith.constant(Int32.mlir_type, 0xC308C308, loc=loc, ip=ip) + for i in range(num_i32_elts): + rsts[i] = llvm.inline_asm( + Int32.mlir_type, + [ + rsts[i], + mul, + bias, + ], + "fma.rn.bf16x2 $0, $1, $2, $3;", + "=r,r,r,r", + ) + # pack rsts into a vector + vec_type = ir.VectorType.get([num_i32_elts], Int32.mlir_type, loc=loc) + vec_rsts = vector.from_elements(vec_type, rsts, loc=loc, ip=ip) + return vec_rsts + + +# Convert 2 int4 values to 2 bfloat16 values +@dsl_user_op +def cvt_i4x2_to_bf16x2(src_vec2, *, loc=None, ip=None): + # pack 2 int4 into 1 int32 value and fill upper bits with 0 + src_i8 = llvm.bitcast(Int8.mlir_type, src_vec2, loc=loc, ip=ip) + src_i32 = llvm.zext(Int32.mlir_type, src_i8, loc=loc, ip=ip) + rst_i32 = cvt_i4_to_bf16_impl(src_i32, 2, loc=loc, ip=ip) + vec_bf16x2_type = ir.VectorType.get([2], BFloat16.mlir_type, loc=loc) + vec_bf16x2 = llvm.bitcast(vec_bf16x2_type, rst_i32, loc=loc, ip=ip) + return vec_bf16x2 + + +# Convert 4 int4 values to 4 bfloat16 values +@dsl_user_op +def cvt_i4x4_to_bf16x4(src_vec4, *, loc=None, ip=None): + # pack 4 int4 into 1 int32 value and fill upper bits with 0 + src_i16 = llvm.bitcast(Int16.mlir_type, src_vec4, loc=loc, ip=ip) + src_i32 = llvm.zext(Int32.mlir_type, src_i16, loc=loc, ip=ip) + rst_i32 = cvt_i4_to_bf16_impl(src_i32, 4, loc=loc, ip=ip) + vec_bf16x4_type = ir.VectorType.get([4], BFloat16.mlir_type, loc=loc) + vec_bf16x4 = llvm.bitcast(vec_bf16x4_type, rst_i32, loc=loc, ip=ip) + return vec_bf16x4 + + +# Convert 8 int4 values to 8 bfloat16 values +@dsl_user_op +def cvt_i4x8_to_bf16x8(src_vec8, *, loc=None, ip=None): + # pack 8 int4 into 1 int32 value and fill upper bits with 0 + src_i32 = llvm.bitcast(Int32.mlir_type, src_vec8, loc=loc, ip=ip) + rst_i32 = cvt_i4_to_bf16_impl(src_i32, 8, loc=loc, ip=ip) + vec_bf16x8_type = ir.VectorType.get([8], BFloat16.mlir_type, loc=loc) + vec_bf16x8 = llvm.bitcast(vec_bf16x8_type, rst_i32, loc=loc, ip=ip) + return vec_bf16x8 + + +@dsl_user_op +def log2_of_pow2_int(a: Int32, *, loc=None, ip=None) -> Int32: + tmp = llvm.inline_asm( + Int32.mlir_type, + [a.ir_value(loc=loc, ip=ip)], + "brev.b32 $0, $1;", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return Int32( + llvm.inline_asm( + Int32.mlir_type, + [tmp], + "bfind.shiftamt.u32 $0, $1;", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op @deprecated( "cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead" @@ -679,3 +1270,96 @@ def exp_packed_f32x2( LOG2_E = Float32(1.4426950408889634) b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip) return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip) + + + +@dsl_user_op +def cvt_f4e2m1_f16(src, *, loc=None, ip=None): + # 0 padding for upper 4 bits + zero = arith.constant(src.type, 0, loc=loc, ip=ip) + vec2 = vector.from_elements( + ir.VectorType.get([2], src.type, loc=loc), [src, zero], loc=loc, ip=ip + ) + rst_vec2 = cvt_f4e2m1x2_to_f16x2(vec2, loc=loc, ip=ip) + # only the 1st element is valid + rst = vector.extract( + rst_vec2, dynamic_position=[], static_position=[0], loc=loc, ip=ip + ) + return rst + + +# Convert 2 float4e2m1 values to 2 float16 values +@dsl_user_op +def cvt_f4e2m1x2_to_f16x2(src_vec2, *, loc=None, ip=None): + # pack 2 float4e2m1 into 1 int8 value and fill upper bits with 0 + src_i8 = llvm.bitcast(Int8.mlir_type, src_vec2, loc=loc, ip=ip) + src_i16 = llvm.zext(Int16.mlir_type, src_i8, loc=loc, ip=ip) + rst_i32 = llvm.inline_asm( + Int32.mlir_type, + [src_i16], + """{\n\t + .reg .b8 b;\n\t + mov.b16 {b,_}, $1;\n\t + cvt.rn.f16x2.e2m1x2 $0, b;\n\t + }""", + "=r,h", + ) + vec_f16x2_type = ir.VectorType.get([2], Float16.mlir_type, loc=loc) + vec_f16x2 = llvm.bitcast(vec_f16x2_type, rst_i32, loc=loc, ip=ip) + return vec_f16x2 + + +# Convert 4 float4e2m1 values to 4 float16 values +@dsl_user_op +def cvt_f4e2m1x4_to_f16x4(src_vec4, *, loc=None, ip=None): + # pack 4 float4e2m1 into 1 int16 value + src_i16 = llvm.bitcast(Int16.mlir_type, src_vec4, loc=loc, ip=ip) + rst_i32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [src_i16], + """{\n\t + .reg .b8 b0, b1;\n\t + mov.b16 {b0, b1}, $2;\n\t + cvt.rn.f16x2.e2m1x2 $0, b0;\n\t + cvt.rn.f16x2.e2m1x2 $1, b1;\n\t + }""", + "=r,=r,h", + ) + res0 = llvm.extractvalue(T.i32(), rst_i32x2, [0]) + res1 = llvm.extractvalue(T.i32(), rst_i32x2, [1]) + vec_f32x2_type = ir.VectorType.get([2], Int32.mlir_type, loc=loc) + vec_f32x2 = vector.from_elements(vec_f32x2_type, [res0, res1], loc=loc, ip=ip) + vec_f16x4_type = ir.VectorType.get([4], Float16.mlir_type, loc=loc) + vec_f16x4 = llvm.bitcast(vec_f16x4_type, vec_f32x2, loc=loc, ip=ip) + return vec_f16x4 + + +# Convert 8 float4e2m1 values to 8 float16 values +@dsl_user_op +def cvt_f4e2m1x8_to_f16x8(src_vec8, *, loc=None, ip=None): + # pack 8 float4e2m1 into 1 int32 value and fill upper bits with 0 + src_i32 = llvm.bitcast(Int32.mlir_type, src_vec8, loc=loc, ip=ip) + rst_i32x4 = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [src_i32], + """{\n\t + .reg .b8 b0, b1, b2, b3;\n\t + mov.b32 {b0, b1, b2, b3}, $4;\n\t + cvt.rn.f16x2.e2m1x2 $0, b0;\n\t + cvt.rn.f16x2.e2m1x2 $1, b1;\n\t + cvt.rn.f16x2.e2m1x2 $2, b2;\n\t + cvt.rn.f16x2.e2m1x2 $3, b3;\n\t + }""", + "=r,=r,=r,=r,r", + ) + res0 = llvm.extractvalue(T.i32(), rst_i32x4, [0]) + res1 = llvm.extractvalue(T.i32(), rst_i32x4, [1]) + res2 = llvm.extractvalue(T.i32(), rst_i32x4, [2]) + res3 = llvm.extractvalue(T.i32(), rst_i32x4, [3]) + vec_f32x4_type = ir.VectorType.get([4], Int32.mlir_type, loc=loc) + vec_f32x4 = vector.from_elements( + vec_f32x4_type, [res0, res1, res2, res3], loc=loc, ip=ip + ) + vec_f16x8_type = ir.VectorType.get([8], Float16.mlir_type, loc=loc) + vec_f16x8 = llvm.bitcast(vec_f16x8_type, vec_f32x4, loc=loc, ip=ip) + return vec_f16x8 diff --git a/python/CuTeDSL/cutlass/cute/atom.py b/python/CuTeDSL/cutlass/cute/atom.py new file mode 100644 index 00000000..a8d05a94 --- /dev/null +++ b/python/CuTeDSL/cutlass/cute/atom.py @@ -0,0 +1,1132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from abc import ABC, ABCMeta, abstractmethod +from typing import Type, Union, Optional, Any, overload + +from .typing import Shape, Layout, Tile, Tensor, Numeric, Int32 +from .core import ( + composition, + coalesce, + left_inverse, + filter, + pretty_str, + is_static, + make_layout, + make_layout_tv, + rank, + size, + static, +) +from .tuple import product_each + +# Internal utils +from .core import _unpack_x_tuple, _pack_shape, _pack_coord, _pack_tile +from .tensor import _Tensor, make_tensor + +from cutlass.cutlass_dsl import extract_mlir_values, new_from_mlir_values, dsl_user_op + +from cutlass._mlir import ir +from cutlass._mlir.dialects import cute as _cute_ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + +class Op(ABC): + """ + Operation abstract base class. + """ + + pass + + +class MmaOp(Op, metaclass=ABCMeta): + """ + MMA Operation abstract base class. + """ + + @abstractmethod + def _make_trait(self, *, loc=None, ip=None, **kwargs): + pass + + +class CopyOp(Op, metaclass=ABCMeta): + """ + Copy Operation abstract base class. + """ + + @abstractmethod + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ): + pass + + +class Trait(ABC): + """ + Trait abstract base class. + + Traits are internal-only classes used by Atoms that wrap the underlying IR Value. The Python + user should only interact with Ops and Atoms. + """ + + def __init__(self, value: ir.Value) -> None: + self.value = value + + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + return self.__class__(values[0]) + + def set(self, field, value, *, loc=None, ip=None) -> None: + raise NotImplementedError( + "set not implemented, the requesting Atom has likely no runtime state" + ) + + def get(self, field, *, loc=None, ip=None) -> Any: + raise NotImplementedError( + "get not implemented, the requesting Atom has likely no runtime state" + ) + + def unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: + return self.value + + +def make_atom(ty, values=None, *, loc=None, ip=None): + """ + This is a wrapper around the _cute_ir.make_atom operation, providing default value for the values argument. + """ + if values is None: + values = [] + return _cute_ir.make_atom(ty, values, loc=loc, ip=ip) + + +class Atom(ABC): + """ + Atom base class. + + An Atom is the composition of + + - a MMA or Copy Operation; + - an internal MMA or Copy Trait. + + An Operation is a pure Python class that is used to model a specific MMA or Copy instruction. + The Trait wraps the underlying IR Value and provides access to the metadata of the instruction + encoded using CuTe Layouts. When the Trait can be constructed straighforwardly from an + Operation, the ``make_mma_atom`` or ``make_copy_atom`` API should be used. There are cases where + constructing the metadata is not trivial and requires more information, for example to determine + the number of bytes copied per TMA instruction ("the TMA vector length"). In such cases, + dedicated helper functions are provided with an appropriate API such that the Atom is + constructed internally in an optimal fashion for the user. + """ + + def __init__(self, op: Op, trait: Trait) -> None: + self._op = op + self._trait = trait + + def __extract_mlir_values__(self): + return extract_mlir_values(self._trait) + + def __new_from_mlir_values__(self, values): + return self.__class__(self.op, new_from_mlir_values(self._trait, values)) + + @property + def op(self) -> Op: + return self._op + + @property + def type(self): + return self._trait.value.type + + @dsl_user_op + def set(self, modifier, value, *, loc=None, ip=None) -> None: + """ + Sets runtime fields of the Atom. + + Some Atoms have runtime state, for example a tcgen05 MMA Atom + + + .. code-block:: python + + tiled_mma = cute.make_tiled_mma(some_tcgen05_mma_op) + tiled_mma.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True) + + The ``set`` method provides a way to the user to modify such runtime state. Modifiable + fields are provided by arch-specific enumerations, for example ``tcgen05.Field``. The Atom + instance internally validates the field as well as the value provided by the user to set + the field to. + """ + self._trait.set(modifier, value, loc=loc, ip=ip) + + @dsl_user_op + def get(self, field, *, loc=None, ip=None) -> Any: + """ + Gets runtime fields of the Atom. + + Some Atoms have runtime state, for example a tcgen05 MMA Atom + + .. code-block:: python + + tiled_mma = cute.make_tiled_mma(some_tcgen05_mma_op) + accum = tiled_mma.get(cute.nvgpu.tcgen05.Field.ACCUMULATE) + + The ``get`` method provides a way to the user to access such runtime state. Modifiable + fields are provided by arch-specific enumerations, for example ``tcgen05.Field``. The Atom + instance internally validates the field as well as the value provided by the user to set + the field to. + """ + return self._trait.get(field, loc=loc, ip=ip) + + def _unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: + return self._trait.unpack(loc=loc, ip=ip, **kwargs) + + +#################################################################################################### +# +# MMA Atoms, TiledMma, and ThrMma +# +#################################################################################################### + + +class MmaAtom(Atom): + """ + The MMA Atom class. + """ + + def __str__(self) -> str: + res = "MMA Atom\n" + res += " ThrID: " + pretty_str(self.thr_id) + "\n" + res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" + res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" + res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" + res += " TV Layout C: " + pretty_str(self.tv_layout_C) + return res + + # + # Properties + # + + @property + def thr_id(self) -> Layout: + return static(self._trait.value.type.thr_id) + + @property + def shape_mnk(self) -> Shape: + return _unpack_x_tuple(self._trait.value.type.shape_mnk) + + @property + def tv_layout_A(self) -> Layout: + return static(self._trait.value.type.layout_a_tv) + + @property + def tv_layout_B(self) -> Layout: + return static(self._trait.value.type.layout_b_tv) + + @property + def tv_layout_C(self) -> Layout: + return static(self._trait.value.type.layout_c_tv) + + # + # make_fragment + # + + @dsl_user_op + def make_fragment_A(self, input, *, loc=None, ip=None): + # input could be memref/shape/layout for tmem based fragment + if isinstance(input, _Tensor): + if self.op is not None: + self.op._verify_fragment_A(input, loc=loc, ip=ip) + input = input.value + if isinstance(input, tuple): + input = _pack_shape(input, loc=loc, ip=ip) + return _cute_ir.mma_make_fragment( + _cute_ir.MmaOperand.A, self._trait.value, input, loc=loc, ip=ip + ) + + @dsl_user_op + def make_fragment_B(self, input, *, loc=None, ip=None): + if isinstance(input, _Tensor): + if self.op is not None: + self.op._verify_fragment_B(input, loc=loc, ip=ip) + input = input.value + return _cute_ir.mma_make_fragment( + _cute_ir.MmaOperand.B, self._trait.value, input, loc=loc, ip=ip + ) + + @dsl_user_op + def make_fragment_C(self, input, *, loc=None, ip=None): + # input could be memref/shape/layout for tmem based fragment + if isinstance(input, _Tensor): + input = input.value + if isinstance(input, tuple): + input = _pack_shape(input, loc=loc, ip=ip) + return _cute_ir.mma_make_fragment( + _cute_ir.MmaOperand.C, self._trait.value, input, loc=loc, ip=ip + ) + + +class TiledMma(MmaAtom): + """ + The tiled MMA class. + """ + + def __str__(self) -> str: + res = "Tiled MMA\n" + res += " Thr Layout VMNK: " + pretty_str(self.thr_layout_vmnk) + "\n" + res += " Permutation MNK: " + pretty_str(self.permutation_mnk) + "\n" + res += "MMA Atom\n" + res += " ThrID: " + pretty_str(self.thr_id) + "\n" + res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" + res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" + res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" + res += " TV Layout C: " + pretty_str(self.tv_layout_C) + return res + + # + # Properties + # + + @property + def tv_layout_A_tiled(self) -> Layout: + return static(self._trait.value.type.layout_a_tv_tiled) + + @property + def tv_layout_B_tiled(self) -> Layout: + return static(self._trait.value.type.layout_b_tv_tiled) + + @property + def tv_layout_C_tiled(self) -> Layout: + return static(self._trait.value.type.layout_c_tv_tiled) + + @property + def permutation_mnk(self) -> Tile: + return _unpack_x_tuple(self._trait.value.type.permutation_mnk) + + @property + def thr_layout_vmnk(self) -> Layout: + return static(self._trait.value.type.thr_layout_vmnk) + + @property + def size(self) -> int: + return self._trait.value.type.size + + # + # Tiler + # + + def get_tile_size(self, mode_idx: int) -> Shape: + assert (mode_idx >= 0) and (mode_idx < 3) + perm_tile = self.permutation_mnk[mode_idx] + if perm_tile is None: + thr_layout_vmnk = self.thr_layout_vmnk + atom_shape_mnk = self.shape_mnk + return size(atom_shape_mnk, mode=[mode_idx]) * size( + thr_layout_vmnk, mode=[mode_idx + 1] + ) + else: + return size(perm_tile) + + # + # get_slice + # + + def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrMma": + return ThrMma(self.op, self._trait, thr_idx) + + # + # partition_shape + # + + def _partition_shape(self, operand_id, shape, *, loc=None, ip=None): + shape = _pack_shape(shape, loc=loc, ip=ip) + return _unpack_x_tuple( + _cute_ir.tiled_mma_partition_shape( + operand_id, self._trait.value, shape, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def partition_shape_A(self, shape_mk, *, loc=None, ip=None): + return self._partition_shape(_cute_ir.MmaOperand.A, shape_mk, loc=loc, ip=ip) + + @dsl_user_op + def partition_shape_B(self, shape_nk, *, loc=None, ip=None): + return self._partition_shape(_cute_ir.MmaOperand.B, shape_nk, loc=loc, ip=ip) + + @dsl_user_op + def partition_shape_C(self, shape_mn, *, loc=None, ip=None): + return self._partition_shape(_cute_ir.MmaOperand.C, shape_mn, loc=loc, ip=ip) + + # + # _thrfrg + # + + @overload + def _thrfrg(self, operand_id, input: Layout, *, loc=None, ip=None) -> Layout: ... + + @overload + def _thrfrg(self, operand_id, input: Tensor, *, loc=None, ip=None) -> Tensor: ... + + def _thrfrg(self, operand_id, input, *, loc=None, ip=None) -> Union[Tensor, Layout]: + if isinstance(input, Tensor): + return make_tensor( + input.iterator, + self._thrfrg(operand_id, input.layout, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + elif isinstance(input, Layout): + if not is_static(input.type): + raise ValueError(f"Expects a static layout but got {input.type}") + return static( + self._trait.value.type.thrfrg(operand_id, input), loc=loc, ip=ip + ) + + raise ValueError( + f"Expects a layout or a tensor as input but got {type(input)=}" + ) + + def _thrfrg_A( + self, input: Union[Layout, Tensor], *, loc=None, ip=None + ) -> Union[Layout, Tensor]: + return self._thrfrg(_cute_ir.MmaOperand.A, input, loc=loc, ip=ip) + + def _thrfrg_B( + self, input: Union[Layout, Tensor], *, loc=None, ip=None + ) -> Union[Layout, Tensor]: + return self._thrfrg(_cute_ir.MmaOperand.B, input, loc=loc, ip=ip) + + def _thrfrg_C( + self, input: Union[Layout, Tensor], *, loc=None, ip=None + ) -> Union[Layout, Tensor]: + return self._thrfrg(_cute_ir.MmaOperand.C, input, loc=loc, ip=ip) + + +class ThrMma(TiledMma): + """ + The thread MMA class for modeling a thread-slice of a tiled MMA. + """ + + def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: + super().__init__(op, trait) + self._thr_idx = thr_idx + + def __new_from_mlir_values__(self, values): + return self.__class__( + self.op, new_from_mlir_values(self._trait, values), self.thr_idx + ) + + @property + def thr_idx(self): + return self._thr_idx + + @dsl_user_op + def partition_A(self, input_mk: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_mma_partition( + _cute_ir.MmaOperand.A, + self._trait.value, + input_mk.value, + thr_idx, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def partition_B(self, input_nk: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_mma_partition( + _cute_ir.MmaOperand.B, + self._trait.value, + input_nk.value, + thr_idx, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def partition_C(self, input_mn: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_mma_partition( + _cute_ir.MmaOperand.C, + self._trait.value, + input_mn.value, + thr_idx, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mma_atom(op: MmaOp, *, loc=None, ip=None, **kwargs) -> MmaAtom: + """ + Makes an MMA Atom from an MMA Operation. + + This function creates an MMA Atom from a given MMA Operation. Arbitrary kw arguments can be + provided for Op-specific additional parameters. They are not used as of today. + + :param op: The MMA Operation to construct an Atom for + :type op: MmaOp + :return: The MMA Atom + :rtype: MmaAtom + """ + trait = op._make_trait(loc=loc, ip=ip, **kwargs) + return MmaAtom(op, trait) + + +@dsl_user_op +def make_tiled_mma( + op_or_atom: Union[Op, MmaAtom], + atom_layout_mnk=(1, 1, 1), + permutation_mnk=None, + *, + loc=None, + ip=None, + **kwargs, +) -> TiledMma: + """ + Makes a tiled MMA from an MMA Operation or an MMA Atom. + + :param op_or_atom: The MMA Operation or Atom + :type op_or_atom: Union[Op, MmaAtom] + :param atom_layout_mnk: A Layout describing the tiling of Atom across threads + :type atom_layout_mnk: Layout + :param permutation_mnk: A permutation Tiler describing the tiling of Atom across values including any permutation of such tiling + :type permutation_mnk: Tiler + :return: The resulting tiled MMA + :rtype: TiledMma + """ + if isinstance(op_or_atom, Op): + op = op_or_atom + atom = make_mma_atom(op_or_atom, loc=loc, ip=ip, **kwargs) + elif isinstance(op_or_atom, MmaAtom): + op = op_or_atom.op + atom = op_or_atom + else: + raise TypeError( + f"expected an MMA Op or Atom, but got an instance of {type(op_or_atom)}" + ) + if isinstance(atom_layout_mnk, tuple): + atom_layout_mnk = make_layout(atom_layout_mnk, loc=loc, ip=ip) + if rank(atom_layout_mnk) != 3: + raise ValueError(f"expects rank-3 MNK atom layout, but got {atom_layout_mnk}") + permutation_mnk_ty = None + if permutation_mnk is not None: + permutation_mnk_ty = _pack_tile(permutation_mnk, loc=loc, ip=ip).type + ty = _cute_nvgpu_ir.TiledMmaType.get( + atom._trait.value.type, + atom_layout_mnk.type, + permutation_mnk_ty, + ) + val = _cute_ir.make_tiled_mma(ty, atom._trait.value, loc=loc, ip=ip) + # Instead of modifying atom which might have been provided by the user, create a brand new + # trait instance and replace the Atom ir.Value with the tiled one + trait = new_from_mlir_values(atom._trait, [val]) + return TiledMma(op, trait) + + +#################################################################################################### +# +# Copy Atoms, TiledCopy, and ThrCopy +# +#################################################################################################### + + +class CopyAtom(Atom): + """ + The Copy Atom class. + """ + + def __str__(self) -> str: + res = "Copy Atom\n" + res += " ThrID: " + str(self.thr_id) + "\n" + res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" + res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" + res += " Value type: " + str(self._trait.value.type.value_type) + return res + + # + # Properties + # + + @property + def value_type(self) -> Type[Numeric]: + return Numeric.from_mlir_type(self._trait.value.type.value_type) + + @property + def thr_id(self) -> Layout: + return static(self._trait.value.type.thr_id) + + @property + def layout_src_tv(self) -> Layout: + return static(self._trait.value.type.layout_src_tv) + + @property + def layout_dst_tv(self) -> Layout: + return static(self._trait.value.type.layout_dst_tv) + + +class TiledCopy(CopyAtom): + """ + The tiled Copy class. + """ + + def __str__(self) -> str: + res = "Tiled Copy\n" + res += " Tiler MN: " + pretty_str(self.tiler_mn) + "\n" + res += " TV Layout tiled: " + str(self.layout_tv_tiled) + "\n" + res += "Copy Atom\n" + res += " ThrID: " + str(self.thr_id) + "\n" + res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" + res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" + res += " Value type: " + str(self._trait.value.type.value_type) + return res + + # + # Properties + # + + @property + def layout_tv_tiled(self) -> Layout: + return static(self._trait.value.type.layout_tv_tiled) + + @property + def tiler_mn(self) -> Tile: + return _unpack_x_tuple(self._trait.value.type.tiler_mn) + + @property + def layout_src_tv_tiled(self) -> Layout: + return static(self._trait.value.type.layout_src_tv_tiled) + + @property + def layout_dst_tv_tiled(self) -> Layout: + return static(self._trait.value.type.layout_dst_tv_tiled) + + @property + def size(self) -> int: + return self._trait.value.type.size + + # + # get_slice and retile + # + + def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrCopy": + return ThrCopy(self.op, self._trait, thr_idx) + + @dsl_user_op + def retile(self, src, *, loc=None, ip=None): + return _cute_ir.tiled_copy_retile( + tiled_copy=self._trait.value, input=src.value, loc=loc, ip=ip + ) + + +class ThrCopy(TiledCopy): + """ + The thread Copy class for modeling a thread-slice of a tiled Copy. + """ + + def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: + super().__init__(op, trait) + self._thr_idx = thr_idx + + def __new_from_mlir_values__(self, values): + return self.__class__( + self.op, new_from_mlir_values(self._trait, values), self.thr_idx + ) + + @property + def thr_idx(self): + return self._thr_idx + + @dsl_user_op + def partition_S(self, src: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_copy_partition_S( + self._trait.value, src.value, thr_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def partition_D(self, dst: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_copy_partition_D( + self._trait.value, dst.value, thr_idx, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_copy_atom( + op: CopyOp, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs +) -> CopyAtom: + """ + Makes a Copy Atom from a Copy Operation. + + This function creates a Copy Atom from a given Copy Operation. Arbitrary kw arguments can be + provided for Op-specific additional parameters. + + Example: + + .. code-block:: python + + op = cute.nvgpu.CopyUniversalOp() + atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) + + :param op: The Copy Operation to construct an Atom for + :type op: CopyOp + :param copy_internal_type: An internal data type used to construct the source/destination layouts in unit of tensor elements + :type copy_internal_type: Type[Numeric] + :return: The Copy Atom + :rtype: CopyAtom + """ + trait = op._make_trait(copy_internal_type, loc=loc, ip=ip, **kwargs) + return CopyAtom(op, trait) + + +def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + if type(tiler_mn) is tuple: + tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) + + assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance( + tiler_mn.type + ), f"tiler_mn must be a Tile, but got {type(tiler_mn)}" + assert is_static(layout_tv.type) and is_static(tiler_mn.type), ( + "layout tv and tiler mn must be static" + ) + tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( + atom.type, layout_tv.type, tiler_mn.type + ) + + val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) + # Instead of modifying atom which might have been provided by the user, create a brand new + # trait instance and replace the Atom ir.Value with the tiled one + trait = new_from_mlir_values(atom._trait, [val]) + return TiledCopy(atom.op, trait) + + +def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + """Create a tiled type given a TV partitioner and tiler. + + :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. + :type atom: CopyAtom + :param layout_tv: Thread-value layout + :type layout_tv: Layout + :param tiler_mn: Tile size + :type tiler_mn: Tiler + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +@dsl_user_op +def make_tiled_copy_tv( + atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None +) -> TiledCopy: + """Create a tiled copy given separate thread and value layouts. + + A TV partitioner is inferred based on the input layouts. The input thread layout + must be compact. + + :param atom: Copy atom + :type atom: CopyAtom + :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) + :type thr_layout: Layout + :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs + :type val_layout: Layout + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip) + tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip) + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +@dsl_user_op +def make_cotiled_copy( + atom: CopyAtom, atom_layout_tv: Layout, data_layout: Layout, *, loc=None, ip=None +) -> TiledCopy: + """ + Produce a TiledCopy from thread and value offset maps. + The TV Layout maps threads and values to the codomain of the data_layout. + It is verified that the intended codomain is valid within data_layout. + Useful when threads and values don't care about owning specific coordinates, but + care more about the vector-width and offsets between them. + + Parameters + ---------- + atom : copy atom, e.g. simt_copy and simt_async_copy, tgen05.st, etc. + atom_layout_tv : (tid, vid) -> data addr + data_layout : data coord -> data addr + loc : source location for mlir (optional) + ip : insertion point (optional) + + Returns + ------- + tiled_copy + A tuple of A tiled copy and atom + """ + assert is_static(atom_layout_tv.type) and is_static(data_layout.type), ( + "atom_layout_tv and data_layout must be static" + ) + # data addr -> data coord + inv_layout_ = left_inverse(data_layout, loc=loc, ip=ip) + inv_data_layout = make_layout( + (inv_layout_.shape, (1)), stride=(inv_layout_.stride, (0)), loc=loc, ip=ip + ) + # (tid,vid) -> data_coord + layout_tv_data = composition(inv_data_layout, atom_layout_tv, loc=loc, ip=ip) + + # check validity + atom_layout_v_to_check = coalesce( + make_layout( + atom_layout_tv.shape[1], stride=atom_layout_tv.stride[1], loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + data_layout_v_to_check = coalesce( + composition( + data_layout, + make_layout( + layout_tv_data.shape[1], stride=layout_tv_data.stride[1], loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + assert data_layout_v_to_check == atom_layout_v_to_check, ( + "the memory pointed to by atom_layout_tv does not exist in the data_layout." + ) + + flat_data_shape = product_each(data_layout.shape, loc=loc, ip=ip) + tiler = tuple( + filter( + composition( + make_layout( + flat_data_shape, + stride=tuple( + 0 if j != i else 1 for j in range(rank(flat_data_shape)) + ), + loc=loc, + ip=ip, + ), + layout_tv_data, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + for i in range(rank(flat_data_shape)) + ) + # tile_coord -> data_coord + tile2data = composition( + make_layout(flat_data_shape, loc=loc, ip=ip), tiler, loc=loc, ip=ip + ) + # (tid,vid) -> tile_coord + layout_tv = composition( + left_inverse(tile2data, loc=loc, ip=ip), layout_tv_data, loc=loc, ip=ip + ) + return _make_tiled_copy(atom, layout_tv, tiler, loc=loc, ip=ip) + + +@dsl_user_op +def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_A_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_B_tiled, + (tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_C_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the Src-Layout of tiled_copy. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the Dst-Layout of tiled_copy. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): + """Create the smallest tiled copy that can retile LayoutC_TV for use with pipelined epilogues with subtiled stores. + + :param atom: Copy atom + :type atom: CopyAtom + :param mma: Tiled MMA + :type mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for partitioner + :rtype: TiledCopy + + :raises ValueError: If the number value of CopyAtom's source layout is greater than the size of TiledMma's LayoutC_TV + """ + # Truncate the V-layout to just the Copy_Atom, keep the V-order + layoutC_tv = mma.tv_layout_C_tiled + val_layout_src = atom.layout_src_tv + num_val_src = size(val_layout_src, mode=[1], loc=loc, ip=ip) + num_val_layoutC_tv = size(layoutC_tv, mode=[1], loc=loc, ip=ip) + if num_val_src > num_val_layoutC_tv: + raise ValueError( + f"The number value of CopyAtom's source layout {num_val_src} " + f"is greater than the size of TiledMma's LayoutC_TV {num_val_layoutC_tv}" + ) + layout_TV = composition( + layoutC_tv, + make_layout( + (size(layoutC_tv, mode=[0], loc=loc, ip=ip), num_val_src), loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + # Recompute tiler and restride the TV layout for the new tiler + + # Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + # Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA + mma_tiler = (mma.get_tile_size(0), mma.get_tile_size(1)) + + tiler_0 = filter( + composition( + make_layout(mma_tiler, stride=(1, 0), loc=loc, ip=ip), + layout_TV, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + tiler_1 = filter( + composition( + make_layout(mma_tiler, stride=(0, 1), loc=loc, ip=ip), + layout_TV, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + tiler = (tiler_0, tiler_1) + + tile2mma = composition( + make_layout(mma_tiler, loc=loc, ip=ip), tiler, loc=loc, ip=ip + ) + layout_tv = composition( + left_inverse(tile2mma, loc=loc, ip=ip), layout_TV, loc=loc, ip=ip + ) + + tiler_mn = _pack_tile(tiler, loc=loc, ip=ip) + + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +@dsl_user_op +def copy_atom_call( + atom: CopyAtom, + src: Tensor, + dst: Tensor, + *, + pred: Optional[Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + """Executes a single copy atom operation between two tensors. + + :param atom: Copy atom specifying the transfer operation + :type atom: CopyAtom + :param src: Source tensor with layout profile ``(V)`` + :type src: Tensor + :param dst: Destination tensor with layout profile ``(V)`` + :type dst: Tensor + :param pred: Optional predication tensor for conditional transfers, defaults to None + :type pred: Optional[Tensor], optional + :param loc: Source location information, defaults to None + :type loc: Any, optional + :param ip: Insertion point, defaults to None + :type ip: Any, optional + :param kwargs: Additional copy atom specific arguments + :type kwargs: Dict[str, Any] + :raises TypeError: If source and destination element type bit widths differ + :return: None + :rtype: None + + The copy_atom_call operation executes a single copy atom with the given operands. + Source and destination tensors with layout profile like ``(V)``. + + The ``V-mode`` represents either: + + - A singular mode directly consumable by the provided Copy Atom + - A composite mode requiring recursive decomposition, structured as ``(V, Rest...)``, + + For src/dst layout like ``(V, Rest...)``, the layout profile of ``pred`` must match ``(Rest...)``. + + **Examples**: + + .. code-block:: python + + # Basic copy atom operation + cute.copy_atom_call(copy_atom, src, dst) + + # Predicated copy atom operation + cute.copy_atom_call(copy_atom, src, dst, pred=pred) + + .. note:: + + - Certain Atoms may require additional operation-specific keyword arguments. + - Current implementation limits ``V-mode`` rank to 2 or less. Support for higher ranks is planned + for future releases. + + """ + if isinstance(src.type, _cute_ir.MemRefType) and isinstance( + dst.type, _cute_ir.MemRefType + ): + if src.element_type.width != dst.element_type.width: + raise TypeError( + "`copy_atom_call` currently only supports equal source and destination " + "element type bit width" + ) + + if rank(src, mode=[0]) > 2 or rank(dst, mode=[0]) > 2: + raise NotImplementedError( + "V-mode (mode-0) with rank > 2 is not supported yet, " + f"but got rank(src, mode=[0]) = {rank(src, mode=[0])} and rank(dst, mode=[0]) = {rank(dst, mode=[0])}" + ) + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + if isinstance(pred, Tensor): + pred = pred.value + return _cute_ir.copy_atom_call( + value, src.value, dst.value, pred=pred, loc=loc, ip=ip + ) diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index 12d5e422..204fda58 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -9,84 +9,62 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -import copy as py_copy -from dataclasses import dataclass -import inspect -import math -import operator -from abc import ABC, abstractmethod -from functools import lru_cache, partial, reduce +from functools import partial, reduce from inspect import isclass -from itertools import chain -from typing import ( - Callable, - Iterable, - overload, - List, - Tuple, - Union, - Type, - Any, - Dict, - Optional, -) -from enum import Enum, auto +from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload -from cutlass.cutlass_dsl import ( - const, - T, - lru_cache_ir, - is_dynamic_expression, - for_generate, - yield_out, - if_generate, - extract_mlir_values, - new_from_mlir_values, - _binary_op_type_promote, - not_, - cutlass_arith, - dsl_user_op, -) +from typing_extensions import deprecated from cutlass._mlir import ir -from cutlass._mlir.dialects._ods_common import get_op_result_or_op_results +from cutlass._mlir.dialects import builtin, llvm from cutlass._mlir.dialects import cute as _cute_ir from cutlass._mlir.dialects.cute import ( - ScaledBasis as _ScaledBasis, Ratio as _Ratio, ) +from cutlass._mlir.dialects.cute import ( + ReductionOp as ReductionOp, +) +from cutlass._mlir.dialects.cute import ( + ScaledBasis as _ScaledBasis, +) +from cutlass.cutlass_dsl import ( + T, + const, + cutlass_arith, + dsl_user_op, + extract_mlir_values, + is_dynamic_expression, + lru_cache_ir, + not_, +) -from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir.dialects import llvm, builtin, vector, arith - +from .tuple import find_if, flatten_to_tuple, product_each, transform_leaf, wrap from .typing import ( - Numeric, - Integer, - NumericMeta, + AddressSpace, Boolean, - Int32, - Int8, + ComposedLayout, + Coord, + Float32, + Int, Int16, Int32, Int64, - Float32, - TFloat32, - Int, + Integer, IntTuple, + Layout, + Numeric, + NumericMeta, + Pointer, Shape, Stride, - Coord, - Layout, + Tensor, Tile, Tiler, XTuple, - Tensor, - Pointer, - AddressSpace, - as_numeric, + is_int_tuple, + is_integer, ) - #################################################################################################### # # Internal IntTuple helpers @@ -96,9 +74,10 @@ from .typing import ( def _get_typed_value(x): if isinstance(x, Integer): - return ( - x.value.get_typed_value() if isinstance(x.value, IntValue) else x.ir_value() - ) + x = x.ir_value() + + if isinstance(x, IntValue): + return x.get_typed_value() else: return x @@ -118,12 +97,18 @@ def _pack_shape(shape: Shape, *, loc=None, ip=None) -> ir.Value: def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: _check_stride(stride) + dyn_elems = map(_get_typed_value, extract_mlir_values(stride)) # Convert basis elements to the base class before _pack_x stride = transform_leaf( - lambda x: x.to(_cute_ir.ScaledBasis) if isinstance(x, ScaledBasis) else x, + lambda x: ( + x.to(_cute_ir.ScaledBasis) + if isinstance(x, ScaledBasis) + else _get_typed_value(x) + ), stride, ) - return _pack_x(stride, _cute_ir.pack_stride, _cute_ir.MakeStrideOp, loc=loc, ip=ip) + res_ty, _ = _cute_ir.pack_stride(stride) + return _cute_ir.MakeStrideOp(res_ty, dyn_elems, loc=loc, ip=ip).result def _pack_coord(coord: Coord, *, loc=None, ip=None) -> ir.Value: @@ -166,7 +151,7 @@ def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple if isinstance(t, ir.Type): if not _cute_ir.is_static(t): raise ValueError() - t = _cute_ir.static(t) + t = static(t, loc=loc, ip=ip) if isinstance(t, ir.Value): input_ty = t.type @@ -174,14 +159,14 @@ def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple # Handle this case separately, _cute_ir.get_leaves will return an Op in this case vals = [] else: - vals = _cute_ir.get_leaves(t, loc=loc, ip=ip) + vals = get_leaves(t, loc=loc, ip=ip) if not isinstance(vals, list): vals = [vals] else: raise TypeError(f"expects static type or value, but got {t}") # CuTe IR only supports Int32 for now. Need to support detection of other types - res = _cute_ir.unpack_x_tuple(input_ty, vals) + res = _cute_ir.unpack_x_tuple(input_ty, vals, loc=loc) def post_process(x): if isinstance(x, _cute_ir.ScaledBasis): @@ -206,7 +191,9 @@ def _check_shape(shape: Shape) -> None: raise ValueError( f"Expected size in shape to be strictly positive, but got {shape}" ) - elif isinstance(shape, Integer): + elif isinstance(shape, Integer) or ( + isinstance(shape, ir.Value) and isinstance(shape.type, ir.IntegerType) + ): pass else: raise TypeError(f"Expected size be int or Integer, but got {type(shape)}") @@ -281,18 +268,20 @@ class IntValue(cutlass_arith.ArithValue): * get_divisibility() - Returns the divisibility constraint of the value """ - def __init__(self, v, signed=True): + @dsl_user_op + def __init__(self, v, signed=True, *, loc=None, ip=None): # Cute Constrained Int Type is always signed if isinstance(v, int): - v = _pack_int_tuple(v) + v = _pack_int_tuple(v, loc=loc, ip=ip) if isinstance(v.type, _cute_ir.IntTupleType): - scalar_val = _cute_ir.get_scalars(v) - super().__init__(scalar_val, True) + scalar_val = _cute_ir.get_scalars(v, loc=loc, ip=ip) + super().__init__(scalar_val, True, loc=loc, ip=ip) else: - super().__init__(v, True) + super().__init__(v, True, loc=loc, ip=ip) - def get_typed_value(self): + @dsl_user_op + def get_typed_value(self, *, loc=None, ip=None): if isinstance(self.type, ir.IntegerType): def_op = self.owner.operation if def_op.name == "cute.get_scalars": @@ -300,18 +289,21 @@ class IntValue(cutlass_arith.ArithValue): assert not isinstance(self.type, _cute_ir.IntTupleType) - return _pack_int_tuple(self) + # get_typed_value is called by _pack_int_tuple, copy code to avoid + # recursive calls + res_ty, _ = _cute_ir.pack_int_tuple(self) + return _cute_ir.MakeIntTupleOp(res_ty, [self], loc=loc, ip=ip).result @property def divisibility(self): - if isinstance(self.get_typed_value().type, _cute_ir.IntTupleType): - return self.get_typed_value().type.get_divisibility([0]) - else: - return 1 + assert isinstance( + self.get_typed_value().type, _cute_ir.IntTupleType + ), f"expected self.get_typed_value() to be int_tuple type, but got {self.get_typed_value().type}" + return self.get_typed_value().type.get_divisibility([0]) def __str__(self): if self.divisibility == 1: - return f"?" + return "?" else: return f"?{{div={self.divisibility}}}" @@ -322,7 +314,6 @@ class IntValue(cutlass_arith.ArithValue): def pretty_str(self): return self.__str__() - @staticmethod def _binary_op(op): def wrapper(self, other, **kwargs): if isinstance(other, IntValue): @@ -332,7 +323,6 @@ class IntValue(cutlass_arith.ArithValue): ): other_val = other elif isinstance(other, ir.Value) and isinstance(other.type, ir.IntegerType): - other = cutlass_arith.int_to_int(other, Int32, **kwargs) other_val = _pack_int_tuple(other) elif isinstance(other, (int, bool)): other_val = _pack_int_tuple(int(other)) @@ -347,52 +337,72 @@ class IntValue(cutlass_arith.ArithValue): @dsl_user_op @_binary_op def __add__(self, other, *, loc=None, ip=None): - return _cute_ir.add_offset(self.get_typed_value(), other, loc=loc, ip=ip) + return _cute_ir.add_offset( + self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __sub__(self, other, *, loc=None, ip=None): - return _cute_ir.tuple_sub(self.get_typed_value(), other, loc=loc, ip=ip) + return _cute_ir.tuple_sub( + self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __mul__(self, other, *, loc=None, ip=None): - return _cute_ir.tuple_mul(self.get_typed_value(), other, loc=loc, ip=ip) + return _cute_ir.tuple_mul( + self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __floordiv__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_div(self.get_typed_value(), other, loc=loc, ip=ip) + return _cute_ir.tuple_div( + self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __mod__(self, other, *, loc=None, ip=None) -> cutlass_arith.ArithValue: - return _cute_ir.tuple_mod(self.get_typed_value(), other, loc=loc, ip=ip) + return _cute_ir.tuple_mod( + self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __radd__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.add_offset(other, self.get_typed_value(), loc=loc, ip=ip) + return _cute_ir.add_offset( + other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __rsub__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_sub(other, self.get_typed_value(), loc=loc, ip=ip) + return _cute_ir.tuple_sub( + other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __rmul__(self, other, *, loc=None, ip=None): - return _cute_ir.tuple_mul(other, self.get_typed_value(), loc=loc, ip=ip) + return _cute_ir.tuple_mul( + other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __rfloordiv__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_div(other, self.get_typed_value(), loc=loc, ip=ip) + return _cute_ir.tuple_div( + other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip + ) @dsl_user_op @_binary_op def __rmod__(self, other, *, loc=None, ip=None) -> "IntValue": - return _cute_ir.tuple_mod(other, self.get_typed_value(), loc=loc, ip=ip) + return _cute_ir.tuple_mod( + other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip + ) class Ratio(_Ratio): @@ -532,7 +542,7 @@ class ScaledBasis: self._mode = [mode] else: if any(not isinstance(x, int) for x in mode): - raise TypeError("Mode must be a list of integers") + raise TypeError(f"Mode must be a list of integers, but got {mode}") self._mode = mode self._value = value @@ -545,11 +555,16 @@ class ScaledBasis: """ return not is_dynamic_expression(self._value) - def to(self, dtype): + @dsl_user_op + def to(self, dtype, *, loc=None, ip=None): """Convert to another type. :param dtype: The target type for conversion :type dtype: type + :param loc: The source location for the operation, defaults to None + :type loc: Location, optional + :param ip: The insertion point for the operation, defaults to None + :type ip: InsertionPoint, optional :return: The ScaledBasis converted to the specified type :raises TypeError: If conversion to the specified type is not supported """ @@ -557,15 +572,13 @@ class ScaledBasis: return self elif dtype is _ScaledBasis: if isinstance(self._value, Ratio): - scale = self._value - elif isinstance(self._value, Integer): - scale = self._value.ir_value() - else: - scale = self._value + return _ScaledBasis(self._value, self._mode) - if isinstance(scale, IntValue): - return _ScaledBasis(scale.get_typed_value(), self._mode) + if isinstance(self._value, Integer): + scale = self._value.ir_value(loc=loc, ip=ip) + return _ScaledBasis(scale, self._mode, get_divisibility(scale)) else: + scale = self._value return _ScaledBasis(scale, self._mode) else: raise TypeError(f"Cannot convert ScaledBasis to {dtype}") @@ -574,10 +587,7 @@ class ScaledBasis: return f"{self.to(_ScaledBasis).__str__()}" def __hash__(self): - if isinstance(self.mode, list): - return hash((self.value, tuple(self.mode))) - else: - return hash((self.value, self.mode)) + return hash((self.value, tuple(self.mode))) @property def value(self): @@ -602,7 +612,9 @@ class ScaledBasis: else: return False - def __rmul__(self, scale: Union[Int, ir.Value, Ratio]) -> "ScaledBasis": + def __rmul__( + self, scale: Union[Int, ir.Value, Ratio], *, loc=None, ip=None + ) -> "ScaledBasis": """Right multiplication by a scale factor. This operation is used in layout algebra to scale basis elements, @@ -610,6 +622,10 @@ class ScaledBasis: :param scale: The scale factor :type scale: Union[Int, ir.Value, Ratio] + :param loc: The source location for the operation, defaults to None + :type loc: Location, optional + :param ip: The insertion point for the operation, defaults to None + :type ip: InsertionPoint, optional :return: A new scaled basis element :rtype: ScaledBasis :raises TypeError: If scale is not of a supported type @@ -631,15 +647,22 @@ class ScaledBasis: # Lift to IntValue type to preserve type info as much as possible if isinstance(scale, cutlass_arith.ArithValue): - scale = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(scale, Int32))) + scale = IntValue(_pack_int_tuple(scale)) if isinstance(value, cutlass_arith.ArithValue): - value = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(value, Int32))) + value = IntValue(_pack_int_tuple(value)) elif isinstance(value, Integer): - value = value.ir_value() + value = value.ir_value(loc=loc, ip=ip) return ScaledBasis(scale * value, self.mode) # type: ignore + def __extract_mlir_values__(self): + if isinstance(self.value, Ratio): + # Ratio is always static + return [] + else: + return extract_mlir_values(self.value) + def E(mode: Union[int, List[int]]) -> ScaledBasis: """Create a unit ScaledBasis element with the specified mode. @@ -669,8 +692,8 @@ def E(mode: Union[int, List[int]]) -> ScaledBasis: if isinstance(mode, int): mode = [mode] - if not isinstance(mode, list): - raise TypeError(f"expects a list, got {type(mode)}") + if any(not isinstance(x, int) for x in mode): + raise TypeError(f"mode must be a list of integers, but got {mode}") if not mode: return 1 @@ -777,26 +800,26 @@ class _Layout(Layout): return f"{pretty_str(self.shape)}:{pretty_str(self.stride)}" @property + @dsl_user_op + @lru_cache_ir() def shape(self, *, loc=None, ip=None) -> Shape: """Get the shape of the layout. The shape defines the dimensions and structure of the layout's coordinate space. - :param loc: Optional location information for debugging. - :param ip: Optional insertion point for IR generation. :return: The hierarchical shape of the layout. """ return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) @property + @dsl_user_op + @lru_cache_ir() def stride(self, *, loc=None, ip=None) -> Stride: """Get the stride of the layout. The stride defines how coordinates map to linear indices in memory. - :param loc: Optional location information for debugging. - :param ip: Optional insertion point for IR generation. :return: The hierarchical stride of the layout. """ return _unpack_x_tuple( @@ -858,7 +881,7 @@ class _Layout(Layout): """ if isinstance(other, Layout): return other.__ne__(self) - return False + return True def __getitem__(self, idx: int) -> Layout: """ @@ -868,7 +891,11 @@ class _Layout(Layout): @dsl_user_op def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: - return crd2idx(coord, self, loc=loc, ip=ip) + if has_underscore(coord): + crd_val = _pack_coord(coord, loc=loc, ip=ip) + return _cute_ir.slice(self, crd_val, loc=loc, ip=ip) + else: + return crd2idx(coord, self, loc=loc, ip=ip) @dsl_user_op def get_hier_coord(self, idx, *, loc=None, ip=None) -> Coord: @@ -889,72 +916,25 @@ class _Layout(Layout): # map linear index back to coordinate: 5 -> (1, 1) coord = get_hier_coord(5, layout) """ - idx_val = Int32(idx).ir_value() + idx_val = Int32(idx).ir_value(loc=loc, ip=ip) crd = _cute_ir.get_hier_coord(idx_val, self, loc=loc, ip=ip) - return _unpack_x_tuple(crd) + return _unpack_x_tuple(crd, loc=loc, ip=ip) @dsl_user_op def get_flat_coord(self, idx, *, loc=None, ip=None) -> Coord: - idx_val = Int32(idx).ir_value() + idx_val = Int32(idx).ir_value(loc=loc, ip=ip) res = _cute_ir.get_flat_coord(idx_val, self, loc=loc, ip=ip) return _unpack_x_tuple(res, loc=loc, ip=ip) @ir.register_value_caster(_cute_ir.ComposedLayoutType.get_static_typeid(), replace=True) -class ComposedLayout(ir.Value): - r"""ComposedLayout represents the functional composition of layouts in CuTe. +class _ComposedLayout(ComposedLayout): + """DSL wrapper of built-in ComposedLayout of CuTe IR where inner layout is one of following: + - Swizzle + - normal Layout - A ComposedLayout is formed by the composition of three components: - inner o offset o outer, where: - - - inner: The inner layout or swizzle that is applied last - - offset: An integer tuple representing a coordinate offset - - outer: The outer layout that is applied first - - ComposedLayout implements the functional composition operation where: - - .. math:: - - R(c) := (inner \\circ offset \\circ outer)(c) := inner(offset + outer(c)) - - This composition allows for complex transformations of coordinates and indices, - enabling operations like tiling, partitioning, and reshaping of data. - - :ivar inner: The inner layout or swizzle component - :ivar offset: The coordinate offset applied between inner and outer layouts - :ivar outer: The outer layout component - :ivar max_alignment: The maximum alignment of the composed layout - - **Examples:** - - .. code-block:: python - - # Create a composed layout with inner layout, offset, and outer layout - - # inner layout: (4, 8):(1, 4) - inner_layout = make_layout((4, 8)) - - offset = (0, 0) - - # outer layout: (2, 2):(1@0, 1@1) - outer_layout = make_layout((2, 2), stride=(1 * E(0), 1 * E(1))) - - # composed layout: (inner o offset o outer) - composed = make_composed_layout(inner_layout, offset, outer_layout) - - # Accessing components of the composed layout - inner = composed.inner - offset = composed.offset - outer = composed.outer - - # map coordinate (0, 1) to linear index - # - outer(0, 1) = (0, 1) - # - offset + outer(0, 1) = (0, 1) - # - inner(0, 1) = 0 * 1 + 1 * 4 = 4 - idx = crd2idx((0, 1), composed) - - # Composition is used in many tiling operations - # For example, in logical_product, raked_product, and blocked_product + The generalized composed layout can support arbitrary function mapping from coordinate + to coordinate as inner layout. """ def __init__(self, value) -> None: @@ -962,33 +942,49 @@ class ComposedLayout(ir.Value): :param value: The operation result value to wrap. """ - super().__init__(value) + self.value = value def __str__(self) -> str: return f"{pretty_str(self.inner)} o {pretty_str(self.offset)} o {pretty_str(self.outer)}" @property + def type(self) -> ir.Type: + return self.value.type + + @property + def is_normal(self) -> bool: + return self.type.is_normal_layout + + @property + @dsl_user_op def inner(self, *, loc=None, ip=None) -> Union[Swizzle, Layout]: - return _cute_ir.composed_get_inner(self, loc=loc, ip=ip) + return _cute_ir.composed_get_inner(self.value, loc=loc, ip=ip) @property + @dsl_user_op def offset(self, *, loc=None, ip=None) -> IntTuple: - return _unpack_x_tuple(_cute_ir.composed_get_offset(self, loc=loc, ip=ip)) + return _unpack_x_tuple( + _cute_ir.composed_get_offset(self.value, loc=loc, ip=ip), loc=loc, ip=ip + ) @property + @dsl_user_op def outer(self, *, loc=None, ip=None) -> Layout: - return _cute_ir.composed_get_outer(self, loc=loc, ip=ip) + return _cute_ir.composed_get_outer(self.value, loc=loc, ip=ip) @property + @dsl_user_op def shape(self, *, loc=None, ip=None) -> Shape: - return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) + return _unpack_x_tuple( + _cute_ir.get_shape(self.value, loc=loc, ip=ip), loc=loc, ip=ip + ) @property def max_alignment(self) -> int: return self.type.max_alignment def __eq__(self, other) -> Union[bool, Boolean]: - if isinstance(other, ComposedLayout): + if isinstance(other, _ComposedLayout): if is_static(self.type) and is_static(other.type): return self.type == other.type else: @@ -999,7 +995,7 @@ class ComposedLayout(ir.Value): return False def __req__(self, other) -> Union[bool, Boolean]: - if isinstance(other, ComposedLayout): + if isinstance(other, _ComposedLayout): return Boolean(other.__eq__(self)) return False @@ -1007,20 +1003,36 @@ class ComposedLayout(ir.Value): return not self.__eq__(other) def __rne__(self, other) -> Union[bool, Boolean]: - if isinstance(other, ComposedLayout): + if isinstance(other, _ComposedLayout): return other.__ne__(self) - return False + return True - def __getitem__(self, idx: int) -> "ComposedLayout": + @dsl_user_op + def __getitem__(self, idx: int, *, loc=None, ip=None) -> "_ComposedLayout": """ Top-level `get` to provide a syntax similar to `tuple`. """ - return get(self, mode=[idx]) + return get(self, mode=[idx], loc=loc, ip=ip) @dsl_user_op def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: return crd2idx(coord, self, loc=loc, ip=ip) + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + # Only expecting single value of _ComposedLayout or ir.Value + # In this context, a _ComposedLayout instance is an encapsulated ir.Value which is automatically created + # by value caster for ComposedLayout typed values + assert len(values) == 1, f"Expected 1 value, but got {len(values)}" + assert isinstance( + values[0], (_ComposedLayout, ir.Value) + ), f"Expected _ComposedLayout or ir.Value, but got {type(values[0])}" + return _ComposedLayout( + values[0] if isinstance(values[0], ir.Value) else values[0].value, + ) + @ir.register_value_caster(_cute_ir.PtrType.get_static_typeid(), replace=True) class _Pointer(Pointer): @@ -1074,8 +1086,11 @@ class _Pointer(Pointer): @property @lru_cache_ir() - def dtype(self) -> Type[Numeric]: - return Numeric.from_mlir_type(self.value.type.value_type) + def dtype(self) -> Union[Type[Numeric], _cute_ir.SparseElemType]: + if isinstance(self.value.type.value_type, _cute_ir.SparseElemType): + return self.value.type.value_type + else: + return Numeric.from_mlir_type(self.value.type.value_type) @property def alignment(self) -> int: @@ -1098,15 +1113,12 @@ class _Pointer(Pointer): # Only use if you absolutely need to get the LLVM pointer Value @property + @dsl_user_op @lru_cache_ir() def llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: """ Get the LLVM pointer representation of this pointer. - :param loc: The source location for the operation, defaults to None - :type loc: Location, optional - :param ip: The insertion point for the operation, defaults to None - :type ip: InsertionPoint, optional :return: The LLVM pointer representation :rtype: ir.Value """ @@ -1115,19 +1127,29 @@ class _Pointer(Pointer): [llvm_ptr_ty], [self.value], loc=loc, ip=ip ) - def __add__(self, offset: IntTuple) -> Pointer: + @dsl_user_op + def __add__(self, offset: Int, *, loc=None, ip=None) -> Pointer: """ Offset the pointer by elements of a layout's codomain. :param offset: The offset to add to the pointer - :type offset: IntTuple + :type offset: Int :return: A new pointer offset by the specified amount :rtype: ir.Value """ - offset = _pack_int_tuple(offset) - return _cute_ir.add_offset(self.value, offset=offset) + offset = _pack_int_tuple(offset, loc=loc, ip=ip) # type: ignore + return _cute_ir.add_offset(self.value, offset=offset, loc=loc, ip=ip) @dsl_user_op + def __radd__(self, offset: Int, *, loc=None, ip=None) -> Pointer: + return self.__add__(offset, loc=loc, ip=ip) + + @dsl_user_op + def __sub__(self, offset: Int, *, loc=None, ip=None) -> Pointer: + return self.__add__(-offset, loc=loc, ip=ip) # type: ignore + + @dsl_user_op + @lru_cache_ir() def toint(self, *, loc=None, ip=None): if self.memspace in (AddressSpace.gmem, AddressSpace.generic): res_type = Int64 @@ -1181,540 +1203,6 @@ class _Pointer(Pointer): ) -@ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True) -@ir.register_value_caster(_cute_ir.CoordTensorType.get_static_typeid(), replace=True) -@ir.register_value_caster( - _cute_nvgpu_ir.SmemDescViewType.get_static_typeid(), replace=True -) -class _Tensor(Tensor): - """A tensor class representing the composition of an iterator (engine) with a layout. - - A tensor evaluates the layout by mapping a coordinate to the codomain, offsets the - iterator accordingly, and dereferences the result to obtain the tensor's value. - Formally: T(c) = (E ∘ L)(c) = *(E + L(c)), where E is the iterator/engine and L is the layout. - - :param value: The MLIR operation result value to initialize the tensor with - :type value: ir.Value - :param dtype: The user specified data type of the tensor elements. It could be \ - different from the underlying dtype in the iterator. The default is None. - :type dtype: Type[Numeric], optional - - Attributes: - iterator: The pointer or iterator (engine) component of the tensor - layout: The layout component defining the mapping from coordinates to offsets - shape: The shape of the tensor, inherited from the layout - stride: The stride of the tensor, inherited from the layout - element_type: The data type of the tensor elements - memspace: The memory space where the tensor data resides - - Notes: - - The tensor supports both direct element access via coordinates and slicing operations - - Load/store operations are only supported for specific memory spaces (rmem, smem, gmem, generic) - - For composed layouts, stride information is not directly accessible - - Dynamic layouts do not support vector load/store operations - - **Examples:** - - .. code-block:: python - - # Create a tensor with shape (4,8) in row-major layout - tensor = make_tensor(ptr, make_layout(shape=(4,8), stride=(8,1))) - - # Access individual element - val = tensor[0, 0] # or val = tensor[(0, 0)] - - # Slice operation - get first column - subtensor = tensor[None, 0] # or subtensor = tensor[(None, 0)] - """ - - def __init__(self, value, dtype: Optional[Type[Numeric]] = None): - self._dtype = dtype - if isinstance(value, ir.Value): - self.value = value - elif isinstance(value, _Tensor): - self.value = value.value - else: - raise TypeError(f"Expected ir.Value or core._Tensor, got {type(value)}") - - # Set iterator - iter_val = _cute_ir.get_iter(self.value) - if isinstance(iter_val, Pointer): - self._iterator = iter_val - elif isinstance(iter_val.type, _cute_ir.IntTupleType): - self._iterator = _unpack_x_tuple(iter_val) - elif isinstance(iter_val, ir.Value): - # Example: SMEM descriptor iterator, not well supported today - self._iterator = iter_val - else: - raise TypeError(f"unsupported iterator type, got {type(iter_val)}") - - # Set dtype - if self._dtype is None: - if is_int_tuple(self.iterator): - self._dtype = IntTuple - elif isinstance(self.iterator, Pointer): - self._dtype = self.iterator.value_type - elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType): - # SmemDescViewType do not need dtype - self._dtype = None - else: - raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") - - def __str__(self): - return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>" - - def __extract_mlir_values__(self): - return [self.value] - - def __new_from_mlir_values__(self, values): - # Only expecting single value of _Tensor or ir.Value - # In this context, a _Tensor instance is an encapsulated ir.Value which is automatically created - # by value caster for MemRef/CoordTensor/SmemDescView typed values - assert len(values) == 1, f"Expected 1 value, but got {len(values)}" - assert isinstance( - values[0], (_Tensor, ir.Value) - ), f"Expected _Tensor or ir.Value, but got {type(values[0])}" - return _Tensor( - values[0] if isinstance(values[0], ir.Value) else values[0].value, - dtype=self.element_type, - ) - - # Cheat to let `Type(_Tensor())` to return cute.Tensor - @property - def __class__(self) -> Type[Tensor]: - return Tensor - - # Make it behave as if it inherited from ir.Value - @property - @lru_cache_ir() - def type(self) -> ir.Type: - return self.value.type - - @dsl_user_op - def __getitem__( - self, crd: Coord, *, loc=None, ip=None - ) -> Union[Tensor, Numeric, IntTuple]: - """Access or slice tensor elements using coordinates. - - This method implements - * tensor evaluation T(c) = *(E + L(c)) when `c` is a coordinate without slicing, or - * tensor slicing operations T(c) = make_tensor(E + L(c), slice(L, c)) - where E is the iterator/engine and L is the layout - - :param crd: Coordinate or slice specification for accessing tensor elements - :type crd: Coord - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: Tensor element value or sliced subtensor - :rtype: Union[Tensor, ir.Value, IntTuple] - - :raises ValueError: If coordinate access is invalid for the tensor layout - - **Examples:** - - .. code-block:: python - - # Create a tensor with pointer iterator - ptr = make_ptr(cutlass.Float32, 0, cutlass.AddressSpace.gmem) - layout = make_layout((64, 128)) # leftmost mode is major - tensor = make_tensor(ptr, layout) # Tensor using pointer iterator - - # Direct element access loads from memory - val = tensor[0] # Loads element at offset 0 - val = tensor[1] # Loads element at offset 4 (4bytes per Float32) - val = tensor[(0, 1)] # Loads element at offset 64 - - # Create a coord tensor - layout = make_layout((64, 128), stride=(1 * E(0), 1 * E(1))) - tensor = make_tensor((128, 128), layout) - - # Direct element access - val = tensor[0] # Returns (128, 128) - val = tensor[(0, 1)] # Returns (128, 129) - - # Slice access - sliced = view[(3, None)] # Returns tensor slice - - .. note:: - Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar - dereference operations. Attempting to set individual elements of tensors with - these element types will result in errors. - - **Examples:** - - .. code-block:: python - - # Unsupported operations with sub-byte types: - ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - # The following will raise an error: - val = tensor[0] # Error: sub-byte scalar dereference not supported - - # Similarly for other sub-byte types: - ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - val = tensor[0] # Error: sub-byte scalar dereference not supported - """ - if has_underscore(crd): - return slice_(self.value, crd) - elif isinstance(self.type, _cute_ir.CoordTensorType): - res = _cute_ir.get_iter(slice_(self, crd).value, loc=loc, ip=ip) - return _unpack_x_tuple(res) - else: - self._check_can_load_store() - self._check_can_dereference() - - crd_val = _pack_coord(crd, loc=loc, ip=ip) - data_val = _cute_ir.memref_load(self.value, crd_val, loc=loc, ip=ip) - return self.element_type(data_val) - - def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None): - orig_dtype = data.dtype - # Implicit upcast to wider type - if ( - data.dtype.is_same_kind(self.element_type) - and self.element_type.width >= data.dtype.width - ): - data = data.to(self.element_type, loc=loc, ip=ip) # type: ignore - - if data.dtype.width != self.element_type.width: - raise ValueError( - f"Type mismatch, store {orig_dtype} (-> {data.dtype}) " - f"to Tensor with element type {self.element_type}" - ) - - if data.dtype is Boolean and self.element_type is Boolean: - # Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory - val = data.ir_value_int8() - else: - val = data.ir_value() - return val - - @dsl_user_op - def __setitem__( - self, - crd: Coord, - data: Union[int, float, ir.Value, Numeric, "TensorSSA"], - *, - loc=None, - ip=None, - ) -> None: - """Set tensor elements at specified coordinates. - - Assigns values to tensor elements through direct coordinate access or slice assignment. - For slice assignment, the value must be a TensorSSA with matching shape. - - :param crd: Coordinate or slice specification for tensor element assignment - :type crd: Coord - :param data: Value to assign - can be scalar or TensorSSA for slice assignment - :type data: Union[int, float, ir.Value, Numeric, TensorSSA] - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - - :raises ValueError: If tensor type doesn't support load/store operations - :raises ValueError: If slice assignment value is not a TensorSSA - :raises ValueError: If value type doesn't match tensor element type - :raises NotImplementedError: If value type is not supported - - .. note:: - Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar - dereference operations. Attempting to set individual elements of tensors with - these element types will result in errors. - - **Examples:** - - .. code-block:: python - - # Unsupported operations with sub-byte types: - ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - # The following will raise an error: - tensor[0] = 1.0 # Error: sub-byte scalar dereference not supported - - # Similarly for other sub-byte types: - ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) - tensor = make_tensor(ptr, layout) - tensor[0] = 0.5 # Error: sub-byte scalar dereference not supported - """ - self._check_can_load_store() - - # convert scalar type - if not has_underscore(crd): - self._check_can_dereference() - # First, convert ir.Value to Numeric - if isinstance(data, ir.Value): - data = as_numeric(data) - elif isinstance(data, (int, float, bool)): - data = as_numeric(data) - - if not isinstance(data, Numeric): - raise ValueError(f"unsupported data type: {type(data)}") - - # Implicit upcast to wider type - val = self._cvt_to_dest(data, loc=loc, ip=ip) - if val.type != self.type.value_type: - raise ValueError( - f"type mismatch, store {val.type} to {self.element_type}" - ) - - crd_val = _pack_coord(crd, loc=loc, ip=ip) - _cute_ir.memref_store(self.value, crd_val, val, loc=loc, ip=ip) - else: - if not isinstance(data, TensorSSA): - raise ValueError(f"expects TensorSSA, but got {data}") - - self.__getitem__(crd).store(data, loc=loc, ip=ip) # type: ignore - - @property - def __class__(self) -> Type[Tensor]: - return Tensor - - # Make it behave as if it inherited from ir.Value - @property - @lru_cache_ir() - def type(self) -> ir.Type: - return self.value.type - - @property - def iterator(self) -> Union[Pointer, IntTuple]: - return self._iterator - - @property - def layout(self) -> Layout: - return _cute_ir.get_layout(self.value) - - @property - def shape(self) -> Shape: - return self.layout.shape - - @property - def stride(self) -> Stride: - if isinstance(self.type, _cute_ir.ComposedLayoutType): - raise ValueError(f"can't get stride from composed layout") - return self.layout.stride - - @property - def leading_dim(self) -> Union[int, Tuple[int], None]: - """Get the leading dimension of this Tensor. - - :return: The index or indices of the first mode (from left to right) with stride 1 - :rtype: Union[int, Tuple[int], None] - :returns: - - int: Single leading dimension index if found - - Tuple[int]: Tuple of indices for nested leading dimensions - - None: If no leading dimension is found - - :postcondition: ``get(self.stride(), mode=self.leading_dim()) == 1 if self.leading_dim() != None else True`` - """ - return leading_dim(self.shape, self.stride) - - @property - @lru_cache_ir() - def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: - return self._dtype - - @property - @lru_cache_ir() - def memspace(self) -> AddressSpace: - if isinstance(self.iterator, Pointer): - return self.iterator.memspace - - raise ValueError(f"{self} doesn't have memspace") - - @dsl_user_op - def load(self, *, loc=None, ip=None) -> "TensorSSA": - """Load tensor elements as a vector. - - Loads all elements of the tensor into a vector representation, assuming the tensor - has a static shape and is in a memory space that supports load operations. - - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: Vector representation of tensor elements - :rtype: TensorSSA - - :raises ValueError: If tensor has dynamic layout - :raises ValueError: If tensor memory space doesn't support load operations - """ - if not is_static(self.shape): - raise ValueError("dynamic layout doesn't support load") - - self._check_can_load_store() - - res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip) - if self.element_type is Boolean: - assert ( - res_vect.type.element_type == T.i8() - ), f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}" - zeros = full_like(self, 0, Int8, loc=loc, ip=ip) - res_vect = arith.cmpi( - arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip - ) - return TensorSSA(res_vect, self.shape, self.element_type) - - @dsl_user_op - def store(self, data: "TensorSSA", *, loc=None, ip=None): - """Store vector data into tensor. - - Stores vector data into the tensor, assuming matching shapes and a memory space - that supports store operations. - - :param data: Vector data to store into tensor - :type data: TensorSSA - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - - :raises ValueError: If tensor has dynamic layout - :raises ValueError: If tensor memory space doesn't support store operations - :raises ValueError: If data shape doesn't match tensor shape - """ - if not isinstance(data, TensorSSA): - raise ValueError(f"Expects TensorSSA, but got {type(data)}") - - if not is_static(self.shape): - raise ValueError("Dynamic layout doesn't support vectorized store") - - self._check_can_load_store() - - n_elems = size(self.shape, loc=loc, ip=ip) - if n_elems != size(data.shape, loc=loc, ip=ip): - raise ValueError( - f"lhs and rhs must have the same shape, but got {self.shape} and {data.shape}" - ) - - elem_mlir_type = cutlass_arith.element_type(data.dtype.mlir_type) - if cutlass_arith.is_narrow_precision(elem_mlir_type): - if elem_mlir_type.width * n_elems % 32 != 0: - raise ValueError( - f"narrow precision type must be 32-bit aligned vector, but got {elem_mlir_type} with {n_elems} elements" - ) - - # Implicit upcast to wider type - new_data = self._cvt_to_dest(data, loc=loc, ip=ip) - - return _cute_ir.memref_store_vec( - new_data, self.value, row_major=True, loc=loc, ip=ip - ) - - @dsl_user_op - def fill(self, value: Numeric, *, loc=None, ip=None) -> None: - """Fill tensor with a constant value. - - Fills all elements of the tensor with the specified value, assuming static size - and supported memory space. - - :param value: Value to fill tensor with - :type value: Union[int, float] - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - - :raises NotImplementedError: If tensor has dynamic size - - **Examples:** - - .. code-block:: python - - # Create tensor from numpy array - b = np.random.randn(4, 8).astype(np.float32) - tensor = from_dlpack(b) - - # Fill tensor with constant value - tensor.fill(0.5) # All elements become 0.5 - """ - self._check_can_load_store() - - sz = size(self, loc=loc, ip=ip) - if type(sz) is not int: - raise NotImplementedError(f"dynamic size is not supported: {self.type}") - - # Should we cast to destination type even with narrow cast? - dst_type = self.element_type - value = dst_type(value) - - self[None] = full(self.shape, fill_value=value, dtype=dst_type, loc=loc, ip=ip) - - def _check_can_load_store(self): - if not isinstance(self.type, _cute_ir.MemRefType) or not self.memspace in ( - AddressSpace.rmem, - AddressSpace.smem, - AddressSpace.gmem, - AddressSpace.generic, - ): - raise ValueError(f"{self} doesn't support load and store") - - def _check_can_dereference(self): - # Check for sub-byte types and raise error if needed - if self.element_type.width % 8 != 0 and self.element_type is not Boolean: - raise ValueError( - f"Sub-byte scalar dereference not supported for type {self.element_type}" - ) - - -@dsl_user_op -def print_tensor( - tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None -): - """Print content of the tensor in human readable format. - - Outputs the tensor data in a structured format showing both metadata - and the actual data values. The output includes tensor type information, - layout details, and a formatted array representation of the values. - - :param tensor: The tensor to print - :type tensor: Tensor - :param verbose: If True, includes additional debug information in the output - :type verbose: bool - :param loc: Source location where it's called, defaults to None - :type loc: source location, optional - :param ip: Insertion pointer for IR generation, defaults to None - :type ip: insertion pointer, optional - :raises NotImplementedError: If the tensor type doesn't support trivial dereferencing - - **Example output:** - - .. code-block:: text - - tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= - [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], - [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], - [ 0.3426, 0.5856, 0.1541, 0.2923, 0.6976], - [-0.1649, 0.8811, 0.1788, 0.1404, 0.2568], - [-0.2944, 0.8593, 0.4171, 0.8998, 0.1766], - [ 0.8814, 0.7919, 0.7390, 0.4566, 0.1576], - [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], - [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) - """ - if isinstance(tensor, TensorSSA): - tmp = make_fragment(tensor.shape, tensor.dtype) - tmp.store(tensor) - tensor = tmp - - if not isinstance(tensor.type, _cute_ir.MemRefType): - raise NotImplementedError( - f"printing {tensor} is not supported because it doesn't support trivial dereferencing. " - f"Coordinate Tensor will be supported in the future." - ) - - tensor._check_can_load_store() # type: ignore - - if tensor.element_type.is_integer: - signed = tensor.element_type.signed - else: - signed = False - - _cute_ir.print_view(tensor.value, verbose=verbose, is_signed=signed, loc=loc, ip=ip) - - #################################################################################################### # # Core API @@ -1722,18 +1210,21 @@ def print_tensor( #################################################################################################### -# -# Utilties -# +def _op_wrapper(op_fn, input): + from .tensor import _Tensor + + if isinstance(input, Tensor): + res = op_fn(input.value) + return _Tensor(res, dtype=input.element_type) + elif isinstance(input, _ComposedLayout): + return op_fn(input.value) + else: + return op_fn(input) -@lru_cache_ir() -def is_integer(a) -> bool: - """Check if an object is static integer or dynamic integer""" - return isinstance(a, (int, Integer)) or ( - isinstance(a, ir.Value) - and isinstance(a.type, (ir.IntegerType, _cute_ir.ConstrainedIntType)) - ) +# +# Utilities +# def is_valid_leaf(a) -> bool: @@ -1747,21 +1238,29 @@ def is_valid_leaf(a) -> bool: ) -def is_int_tuple(a) -> bool: - if isinstance(a, tuple): - return all([is_int_tuple(x) for x in a]) - else: - return is_integer(a) - - -def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool: +def is_static(x: Any) -> bool: """Check if a value is statically known at compile time. In CuTe, static values are those whose values are known at compile time, as opposed to dynamic values which are only known at runtime. + This function checks if a value is static by recursively traversing its type hierarchy + and checking if all components are static. + + Static values include: + - Python literals (bool, int, float, None) + - Static ScaledBasis objects + - Static ComposedLayout objects + - Static IR types + - Tuples containing only static values + + Dynamic values include: + - Numeric objects (representing runtime values) + - Dynamic expressions + - Any tuple containing dynamic values + :param x: The value to check - :type x: Union[ir.Type, ir.Value, XTuple] + :type x: Any :return: True if the value is static, False otherwise :rtype: bool :raises TypeError: If an unsupported type is provided @@ -1773,12 +1272,14 @@ def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool: # Can it be a static int? elif isinstance(x, Numeric): return False + elif isinstance(x, ScaledBasis): + return x.is_static() + elif isinstance(x, _ComposedLayout): + return _cute_ir.is_static(x.type) elif is_dynamic_expression(x): return _cute_ir.is_static(x.type) elif isinstance(x, (bool, int, float)) or x is None: return True - elif isinstance(x, ScaledBasis): - return x.is_static() else: raise TypeError(f"unsupported type {x}") @@ -1807,7 +1308,7 @@ def has_scaled_basis(a: XTuple) -> bool: return isinstance(a, ScaledBasis) -def _tuple_str(t: tuple) -> str: +def _tuple_str(t: Tuple[Any, ...]) -> str: """ Constructs a string representation of a python tuple without calling __repr__ on its elements. """ @@ -1845,26 +1346,48 @@ def pretty_str(arg) -> str: @dsl_user_op def printf(*args, loc=None, ip=None) -> None: """ - Print a value or a list of values. + Print one or more values with optional formatting. - It supports c-style printf format as well: + This function provides printf-style formatted printing capabilities. It can print values directly + or format them using C-style format strings. The function supports printing various types including + layouts, numeric values, tensors, and other CuTe objects. + + The function accepts either: + 1. A list of values to print directly + 2. A format string followed by values to format + + :param args: Variable length argument list containing either: + - One or more values to print directly + - A format string followed by values to format + :type args: Any + :param loc: Source location information for debugging, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for code generation, defaults to None + :type ip: Optional[InsertionPoint] + :raises ValueError: If no arguments are provided + :raises TypeError: If an unsupported argument type is passed + + **Examples:** + + Direct printing of values: .. code-block:: python a = cute.make_layout(shape=(10, 10), stride=(10, 1)) b = cutlass.Float32(1.234) - cute.printf(a, b) - cute.printf("a={}, b={}", a, b) - cute.printf("a={}, b=%.2f", a, b) + cute.printf(a, b) # Prints values directly - :param args: List of values to print - :type args: list - :param loc: Source location where it's called, defaults to None - :type loc: source location, optional - :param ip: Insertion pointer, defaults to None - :type ip: insertion pointer, optional - :raises ValueError: If no arguments are provided or if an unsupported argument type is passed + Formatted printing: + + .. code-block:: python + + # Using format string with generic format specifiers + cute.printf("a={}, b={}", a, b) + + # Using format string with C-style format specifiers + cute.printf("a={}, b=%.2f", a, b) """ + from .tensor import _Tensor if len(args) == 0: raise ValueError("expects at least one argument to print") @@ -1892,10 +1415,12 @@ def printf(*args, loc=None, ip=None) -> None: elif has_scaled_basis(arg0): # Assume it's a stride return _pack_stride(arg0) - elif isinstance(arg0, tuple): - # Assume it's an int_tuple + elif is_int_tuple(arg0): return _pack_int_tuple(arg0) - elif isinstance(arg0, (_Tensor, _Pointer)): + elif isinstance(arg0, tuple): + # Assume it's a tile + return _pack_tile(arg0) + elif isinstance(arg0, (_Tensor, _Pointer, _ComposedLayout)): return arg0.value else: raise TypeError(f"unsupported argument type in printf, got {type(arg)}") @@ -1939,175 +1464,6 @@ def is_major(mode, stride: Stride, *, loc=None, ip=None) -> bool: return True if first_stride == 1 else False -def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: - """ - Find the leading dimension of a shape and stride. - - :param shape: The shape of the tensor or layout - :type shape: Shape - :param stride: The stride of the tensor or layout - :type stride: Stride - :return: The leading dimension index or indices - :rtype: Union[int, Tuple[int, ...], None] - - The return value depends on the stride pattern: - - * If a single leading dimension is found, returns an integer index - * If nested leading dimensions are found, returns a tuple of indices - * If no leading dimension is found, returns None - """ - - def pred_fn(val, pos): - # skip dynamic values which can't be compared - # find the candidate target val, stride at this position is 1 - if (not is_dynamic_expression(val)) and (val == 1): - # extract the shape at this position - mode = [pos] if isinstance(pos, int) else list(pos) - s = get(shape, mode) - if is_dynamic_expression(s) or s != 1: - # shape at this position is dynamic value or not 1 - # we found the leading dimension - return True - return False - - return find_if(stride, pred_fn=pred_fn) - - -@dsl_user_op -def find_if( - t: Union[tuple, ir.Value, int], - pred_fn: Callable[[int, Tuple[int, ...]], bool], - *, - loc=None, - ip=None, -) -> Union[int, Tuple[int, ...], None]: - """Find the first position in t where pred_fn(val, pos) returns True. - - :param t: The search space - :type t: Union[tuple, ir.Value, int] - :param pred_fn: A callable object (lambda, function, etc.) that predicates the value and position in t. - It takes the current leaf value and position, returns True if the value or position is satisfied. - :type pred_fn: Callable[[int, Tuple[int, ...]], bool] - :return: Index if found at top level, tuple of indices showing nested position, or None if not found - :rtype: Union[int, Tuple[int, ...], None] - - **Examples:** - - .. code-block:: python - - # Find the first position of x in t - t = (3, 4) - find_if(t, pred_fn=lambda val, pos: val == x) - - .. code-block:: python - - # find the leading dimension - shape = (3, 4) - stride = (4, 1) - # Find value 1 in stride where the corresponding shape is not 1 - def pred_fn(val, pos): - mode = [pos] if isinstance(pos, int) else list(pos) - return val == 1 and get(shape, mode) != 1 - find_if(stride, pred_fn=pred_fn) - """ - - def _find_if_impl(curr, pos, *, loc=None, ip=None): - if isinstance(curr, tuple): - # Recursively search nested tuple - for i in range(rank(curr)): - sub_curr = get(curr, mode=[i], loc=loc, ip=ip) - sub_pos = (pos, i) if isinstance(pos, int) else pos + (i,) - res_pos = _find_if_impl(sub_curr, sub_pos, loc=loc, ip=ip) - if res_pos is not None: - return res_pos - else: - # For leaf values, check if it matches x - if pred_fn(curr, pos): - return pos - return None - - def _check_pred_fn(): - if not callable(pred_fn): - raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") - signature = inspect.signature(pred_fn) - if len(signature.parameters) != 2: - raise ValueError( - f"pred_fn must have two parameters (value, pos), but got {len(signature.parameters)}" - ) - - _check_pred_fn() - - for i in range(rank(t)): - curr = get(t, mode=[i], loc=loc, ip=ip) - res_pos = _find_if_impl(curr, i, loc=loc, ip=ip) - if res_pos is not None: - return res_pos - return None - - -@dsl_user_op -def find( - t: Union[tuple, ir.Value, int], - x: int, - *, - loc=None, - ip=None, -) -> Union[int, Tuple[int, ...], None]: - """Find the first position of a value ``x`` in a hierarchical structure ``t``. - - Searches for the first occurrence of x in t, optionally excluding positions - where a comparison value matches. The search can traverse nested structures - and returns either a single index or a tuple of indices for nested positions. - - :param t: The search space - :type t: Union[tuple, ir.Value, int] - :param x: The static integer x to search for - :type x: int - :return: Index if found at top level, tuple of indices showing nested position, or None if not found - :rtype: Union[int, Tuple[int, ...], None] - """ - if not isinstance(x, int): - raise TypeError(f"find() requires a static x to search for, but got {x}") - - def pred_fn(val, pos): - # Skip dynamic values which can't be compared - return not is_dynamic_expression(val) and val == x - - return find_if(t, pred_fn=pred_fn, loc=loc, ip=ip) - - -def transform_leaf(f, *args): - """ - Apply a function to the leaf nodes of nested tuple structures. - - This function traverses nested tuple structures in parallel and applies the function f - to corresponding leaf nodes. All input tuples must have the same nested structure. - - :param f: Function to apply to leaf nodes - :type f: Callable - :param args: One or more nested tuple structures with matching profiles - :return: A new nested tuple with the same structure as the inputs, but with leaf values transformed by f - :raises TypeError: If the input tuples have different nested structures - - Example: - - .. code-block:: python - - >>> transform_leaf(lambda x: x + 1, (1, 2)) - (2, 3) - >>> transform_leaf(lambda x, y: x + y, (1, 2), (3, 4)) - (4, 6) - >>> transform_leaf(lambda x: x * 2, ((1, 2), (3, 4))) - ((2, 4), (6, 8)) - """ - if all(isinstance(t, tuple) for t in args): - return tuple(transform_leaf(f, *_args) for _args in zip(*args)) - elif all(not isinstance(t, tuple) for t in args): - return f(*args) - else: - raise TypeError(f"profile of input tuples doesn't match: {args}") - - @dsl_user_op def assume(src, divby=None, *, loc=None, ip=None): if divby is None: @@ -2115,7 +1471,7 @@ def assume(src, divby=None, *, loc=None, ip=None): if isinstance(src, Integer): width = type(src).width - src_val = src.ir_value() + src_val = src.ir_value(loc=loc, ip=ip) else: width = src.type.width src_val = src @@ -2131,7 +1487,22 @@ def make_swizzle(b, m, s, *, loc=None, ip=None): if b == 0: m, s = 4, 3 ty = ir.Type.parse(f'!cute.swizzle<"S<{b},{m},{s}>">') - return Swizzle(_cute_ir.static(ty, loc=loc, ip=ip)) + return Swizzle(static(ty, loc=loc, ip=ip)) + + +@dsl_user_op +def make_sparse_elem(num_logical, num_phys, elem_type, *, loc=None, ip=None): + return _cute_ir.SparseElemType.get(num_logical, num_phys, elem_type.mlir_type) + + +@dsl_user_op +def static(value, *, loc=None, ip=None): + return _cute_ir.static(value, loc=loc, ip=ip) + + +@dsl_user_op +def get_leaves(value, *, loc=None, ip=None): + return _cute_ir.get_leaves(value, loc=loc, ip=ip) # @@ -2152,7 +1523,7 @@ def depth(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: :return: The depth of the input object :rtype: int - Example: + **Example:** .. code-block:: python @@ -2174,7 +1545,7 @@ def depth(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: @lru_cache_ir() -def rank(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: +def rank(a: Union[XTuple, Layout, "ComposedLayout"], mode: List[int] = []) -> int: # type: ignore """Returns the rank (dimensionality) of a tuple, layout, or tensor. The rank of a tuple is its length. For layouts and tensors, the rank is @@ -2189,16 +1560,23 @@ def rank(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: This function is used in layout algebra to determine the dimensionality of tensors and layouts for operations like slicing and evaluation. """ + if isinstance(a, (Layout, ComposedLayout, Tensor)): + return rank(a.shape, mode) + + if (not isinstance(mode, list)) or any(not isinstance(m, int) for m in mode): + raise ValueError(f"Expected 'mode' to be a list of int, but got {mode}") + + if mode: + for x in mode: + a = a[x] + if isinstance(a, tuple): return len(a) - elif isinstance(a, (Layout, ComposedLayout, Tensor)): - return rank(a.shape) elif depth(a) == 0: return 1 else: raise TypeError(f"unsupported type in rank, got {type(a)}") - def is_congruent( a: Union[XTuple, Layout, ComposedLayout, Tensor], b: Union[XTuple, Layout, ComposedLayout, Tensor], @@ -2263,32 +1641,23 @@ def is_weakly_congruent( b = b.shape if not isinstance(a, tuple): return True - if isinstance(a, tuple) and isinstance(b, tuple): + # a and b are both tuple + if isinstance(b, tuple): return (len(a) == len(b)) and all( is_weakly_congruent(x, y) for x, y in zip(a, b) ) - if isinstance(a, tuple) or isinstance(b, tuple): - return False - return True + # a is a tuple, b is not a tuple + return False -@overload -def get(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... -@overload -def get(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... -@overload -def get(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... -@overload -def get(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... -@overload -def get(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... @overload def get(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... @overload def get(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... +@overload +def get(input: XTuple, mode, *, loc=None, ip=None) -> XTuple: ... -@dsl_user_op def get(input, mode: List[int], *, loc=None, ip=None): """Extract a specific element or sub-layout from a layout or tuple. @@ -2339,25 +1708,19 @@ def get(input, mode: List[int], *, loc=None, ip=None): else: if not isinstance(input, (Layout, ComposedLayout)): raise TypeError(f"unsupported type of input, got {type(input)}") - return _cute_ir.get( - input.type.get_op_res_type(mode=mode), input, mode=mode, loc=loc, ip=ip - ) + + if isinstance(input, _ComposedLayout): + input = input.value + res_ty = input.type.get_op_res_type(mode=mode) # type: ignore + return _cute_ir.get(res_ty, input, mode=mode, loc=loc, ip=ip) -@overload -def select(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... -@overload -def select(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... -@overload -def select(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... -@overload -def select(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... -@overload -def select(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... @overload def select(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... @overload def select(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... +@overload +def select(input: XTuple, mode, *, loc=None, ip=None) -> XTuple: ... @dsl_user_op @@ -2402,23 +1765,12 @@ def select(input, mode: List[int], *, loc=None, ip=None): if not isinstance(input, (Layout, ComposedLayout)): raise TypeError(f"unsupported type of input, got {type(input)}") + if isinstance(input, _ComposedLayout): + input = input.value + return _cute_ir.select(input, mode=mode, loc=loc, ip=ip) -@overload -def group_modes(input: Shape, begin: int, end: int, *, loc=None, ip=None) -> Shape: ... -@overload -def group_modes( - input: Stride, begin: int, end: int, *, loc=None, ip=None -) -> Stride: ... -@overload -def group_modes(input: Coord, begin: int, end: int, *, loc=None, ip=None) -> Coord: ... -@overload -def group_modes( - input: IntTuple, begin: int, end: int, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def group_modes(input: Tile, begin: int, end: int, *, loc=None, ip=None) -> Tile: ... @overload def group_modes( input: Layout, begin: int, end: int, *, loc=None, ip=None @@ -2431,10 +1783,14 @@ def group_modes( def group_modes( input: Tensor, begin: int, end: int, *, loc=None, ip=None ) -> Tensor: ... +@overload +def group_modes( + input: XTuple, begin: int, end: int, *, loc=None, ip=None +) -> XTuple: ... @dsl_user_op -def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None): +def group_modes(input, begin: int, end: Optional[int] = None, *, loc=None, ip=None): """Group modes of a hierarchical tuple or layout into a single mode. This function groups a range of modes from the input object into a single mode, @@ -2471,33 +1827,37 @@ def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None): shape = make_shape(2, 3, 4, 5) grouped_shape = group_modes(shape, 0, 2) # Shape ((2, 3), 4, 5) """ + if end is None: + end = rank(input) + + r = rank(input) + begin = max(begin + r, 0) if begin < 0 else begin + end = end + r if end < 0 else end + + if begin >= end: + raise ValueError(f"Expected begin < end, but got {begin} >= {end}") + if depth(input) == 0 and is_integer(input): return (input,) + if isinstance(input, tuple): return (*input[:begin], (input[begin:end]), *input[end:]) - return _cute_ir.group_modes( - input.value if isinstance(input, Tensor) else input, begin, end, loc=loc, ip=ip + + return _op_wrapper( + partial(_cute_ir.group_modes, begin=begin, end=end, loc=loc, ip=ip), input ) -@overload -def slice_(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... -@overload -def slice_(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... -@overload -def slice_(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... -@overload -def slice_(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... -@overload -def slice_(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... @overload def slice_(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... @overload def slice_( - src: ComposedLayout, coord: Coord, *, loc=None, ip=None -) -> ComposedLayout: ... + src: _ComposedLayout, coord: Coord, *, loc=None, ip=None +) -> _ComposedLayout: ... @overload def slice_(src: Tensor, coord: Coord, *, loc=None, ip=None) -> Tensor: ... +@overload +def slice_(src: XTuple, coord: Coord, *, loc=None, ip=None) -> XTuple: ... @dsl_user_op @@ -2576,29 +1936,16 @@ def slice_(src, coord: Coord, *, loc=None, ip=None): else: return () - res_type = None - if isinstance(src, Tensor): - res_type = src.element_type - src = src.value coord_val = _pack_coord(coord, loc=loc, ip=ip) - res = _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + return _op_wrapper(partial(_cute_ir.slice, coord=coord_val, loc=loc, ip=ip), src) @overload -def dice(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... +def dice(src: Layout, dicer: Coord, *, loc=None, ip=None) -> Layout: ... @overload -def dice(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... +def dice(src: ComposedLayout, dicer: Coord, *, loc=None, ip=None) -> ComposedLayout: ... @overload -def dice(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... -@overload -def dice(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... -@overload -def dice(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... -@overload -def dice(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... -@overload -def dice(src: ComposedLayout, coord: Coord, *, loc=None, ip=None) -> ComposedLayout: ... +def dice(src: XTuple, dicer: Coord, *, loc=None, ip=None) -> XTuple: ... @dsl_user_op @@ -2665,29 +2012,23 @@ def dice(src, dicer, *, loc=None, ip=None): return src dicer_val = _pack_coord(dicer, loc=loc, ip=ip) - return _cute_ir.dice(src, dicer_val.type.attribute, loc=loc, ip=ip) - - -def wrap(x) -> tuple: - """ - Wraps the input into a tuple if not a tuple. - """ - if isinstance(x, tuple): - return x - return (x,) + return _op_wrapper( + partial(_cute_ir.dice, coord=dicer_val.type.attribute, loc=loc, ip=ip), src + ) def _extend(func, input, elem, up_to_rank, loc, ip): if input is None: - raise ValueError(f"No input provided for input") + raise ValueError("No input provided for input") - if isinstance(input, (Layout, ComposedLayout)): + if isinstance(input, (_Layout, _ComposedLayout)): if elem is None: elem = make_layout(1) elif not isinstance(elem, Layout): raise TypeError(f"Input type of elem ({type(elem)}) is not accepted!") N = rank(input) + 1 if up_to_rank is None else up_to_rank - return func(N, input, elem, loc=loc, ip=ip) + + return _op_wrapper(partial(func, N, element=elem, loc=loc, ip=ip), input) if is_valid_leaf(input) or isinstance(input, tuple): if elem is None: @@ -2700,7 +2041,7 @@ def _extend(func, input, elem, up_to_rank, loc, ip): if repeat_cnt == 0: return input elif repeat_cnt < 0: - raise ValueError(f"up_to_rank must be >= rank(input)") + raise ValueError("up_to_rank must be >= rank(input)") else: if func is _cute_ir.prepend_to_rank: return (elem,) * repeat_cnt + input @@ -2710,24 +2051,6 @@ def _extend(func, input, elem, up_to_rank, loc, ip): raise TypeError(f"invalid type for input, got {type(input)}") -@overload -def prepend( - input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None -) -> Shape: ... -@overload -def prepend( - input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None -) -> Stride: ... -@overload -def prepend( - input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None -) -> Coord: ... -@overload -def prepend( - input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def prepend(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... @overload def prepend( input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None @@ -2736,6 +2059,10 @@ def prepend( def prepend( input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None ) -> ComposedLayout: ... +@overload +def prepend( + input: XTuple, elem: XTuple, up_to_rank=None, *, loc=None, ip=None +) -> XTuple: ... @dsl_user_op @@ -2779,24 +2106,6 @@ def prepend(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=No return _extend(_cute_ir.prepend_to_rank, input, elem, up_to_rank, loc=loc, ip=ip) -@overload -def append( - input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None -) -> Shape: ... -@overload -def append( - input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None -) -> Stride: ... -@overload -def append( - input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None -) -> Coord: ... -@overload -def append( - input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def append(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... @overload def append( input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None @@ -2805,6 +2114,10 @@ def append( def append( input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None ) -> ComposedLayout: ... +@overload +def append( + input: XTuple, elem: XTuple, up_to_rank=None, *, loc=None, ip=None +) -> XTuple: ... @dsl_user_op @@ -2858,18 +2171,91 @@ def append(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=Non def prepend_ones( t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None ) -> Tensor: + from .tensor import make_tensor + return make_tensor( t.iterator, prepend(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip ) -@dsl_user_op +@overload +def append_ones( + t: Layout, up_to_rank: Union[None, int] = None, *, loc=None, ip=None +) -> Layout: ... + + +@overload def append_ones( t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None -) -> Tensor: - return make_tensor( - t.iterator, append(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip - ) +) -> Tensor: ... + + +@dsl_user_op +def append_ones(t, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): + from .tensor import make_tensor + + if isinstance(t, Tensor): + return make_tensor( + t.iterator, append(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip + ) + elif isinstance(t, Layout): + return append(t, make_layout(1), up_to_rank) + else: + raise TypeError(f"expects Tensor or Layout, but got {type(t)}") + + +def repeat_as_tuple(x, n) -> tuple: + """Creates a tuple with x repeated n times. + + This function creates a tuple by repeating the input value x n times. + + :param x: The value to repeat + :type x: Any + :param n: Number of times to repeat x + :type n: int + :return: A tuple containing x repeated n times + :rtype: tuple + + **Examples:** + + .. code-block:: python + + repeat_as_tuple(1, 1) # Returns (1,) + repeat_as_tuple(1, 3) # Returns (1, 1, 1) + repeat_as_tuple(None, 4) # Returns (None, None, None, None) + """ + if n < 1: + raise ValueError("n must be >= 1") + + return (x,) * n + + +def repeat(x, n): + """Creates an object by repeating x n times. + + This function creates an object by repeating the input value x n times. + If n=1, returns x directly, otherwise returns a tuple of x repeated n times. + + :param x: The value to repeat + :type x: Any + :param n: Number of times to repeat x + :type n: int + :return: x if n=1, otherwise a tuple containing x repeated n times + :rtype: Union[Any, tuple] + :raises ValueError: If n is less than 1 + + **Examples:** + + .. code-block:: python + + repeat(1, 1) # Returns 1 + repeat(1, 3) # Returns (1, 1, 1) + repeat(None, 4) # Returns (None, None, None, None) + """ + if n < 1: + raise ValueError("n must be >= 1") + + return x if n == 1 else (x,) * n def repeat_like(x, target): @@ -2902,37 +2288,14 @@ def repeat_like(x, target): return tuple(repeat_like(x, t) for t in target) -def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple: - """Flattens a potentially nested tuple structure into a flat tuple. - - This function recursively traverses the input structure and flattens it into - a single-level tuple, preserving the order of elements. - - :param a: The structure to flatten - :type a: Union[IntTuple, Coord, Shape, Stride] - :return: A flattened tuple containing all elements from the input - :rtype: tuple - - **Examples:** - - .. code-block:: python - - flatten_to_tuple((1, 2, 3)) # Returns (1, 2, 3) - flatten_to_tuple(((1, 2), 3)) # Returns (1, 2, 3) - flatten_to_tuple((1, (2, (3,)))) # Returns (1, 2, 3) - """ - if not isinstance(a, tuple): - return wrap(a) - else: - return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a))) - - @overload -def flatten(a: Union[IntTuple, Coord, Shape, Stride]) -> IntTuple: ... +def flatten(a: Layout) -> Layout: ... + + @overload def flatten(a: Tensor) -> Tensor: ... @overload -def flatten(a: Layout) -> Layout: ... +def flatten(a: XTuple) -> XTuple: ... def flatten(a): @@ -2959,6 +2322,8 @@ def flatten(a): flatten(Tensor(layout)) # Returns Tensor(flatten(layout)) """ + from .tensor import make_tensor + if isinstance(a, Tensor): return make_tensor(a.iterator, flatten(a.layout)) elif isinstance(a, Layout): @@ -2969,50 +2334,6 @@ def flatten(a): return a -def unflatten( - sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]], profile: XTuple -) -> XTuple: - """Unflatten a flat tuple into a nested tuple structure according to a profile. - - This function transforms a flat sequence of elements into a nested tuple structure - that matches the structure defined by the profile parameter. It traverses the profile - structure and populates it with elements from the sequence. - - sequence must be long enough to fill the profile. Raises RuntimeError if it is not. - - :param sequence: A flat sequence of elements to be restructured - :type sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]] - :param profile: A nested tuple structure that defines the shape of the output - :type profile: XTuple - :return: A nested tuple with the same structure as profile but containing elements from sequence - :rtype: XTuple - - Example: - >>> unflatten([1, 2, 3, 4], ((0, 0), (0, 0))) - ((1, 2), (3, 4)) - """ - - def _make_generator(): - for element in sequence: - yield element - - xs = _make_generator() - return transform_leaf(lambda _: next(xs), profile) - - -@dsl_user_op -def elem_less( - lhs: Union[Shape, IntTuple, Coord], - rhs: Union[Shape, IntTuple, Coord], - *, - loc=None, - ip=None, -): - lhs_val = _pack_coord(lhs, loc=loc, ip=ip) - rhs_val = _pack_coord(rhs, loc=loc, ip=ip) - return Boolean(_cute_ir.elem_less(lhs_val, rhs_val, loc=loc, ip=ip)) - - @overload def filter_zeros( input: Layout, *, target_profile=None, loc=None, ip=None @@ -3033,8 +2354,8 @@ def filter_zeros(input, *, target_profile=None, loc=None, ip=None): :param input: The input layout or tensor to filter :type input: Layout or Tensor - :param target_profile: Target profile for the filtered result, defaults to None - :type target_profile: optional + :param target_profile: Target stride profile for the filtered result, defaults to None + :type target_profile: Stride, optional :param loc: Source location for MLIR, defaults to None :type loc: optional :param ip: Insertion point, defaults to None @@ -3044,14 +2365,22 @@ def filter_zeros(input, *, target_profile=None, loc=None, ip=None): :raises TypeError: If input is not a Layout or Tensor """ if not isinstance(input, (Layout, Tensor)): - raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") + raise TypeError(f"Expected layout or tensor as input, but got {type(input)=}") if isinstance(input, Tensor): input = input.value return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip) +@overload +def filter(input: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def filter(input: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ... +@overload +def filter(input: Tensor, *, loc=None, ip=None) -> Tensor: ... + + @dsl_user_op -def filter(input: Union[Layout, Tensor], *, loc=None, ip=None): +def filter(input, *, loc=None, ip=None): """Filter a layout or tensor. This function filters a layout or tensor according to CuTe's filtering rules. @@ -3066,117 +2395,19 @@ def filter(input: Union[Layout, Tensor], *, loc=None, ip=None): :rtype: Layout or Tensor :raises TypeError: If input is not a Layout or Tensor """ - if not isinstance(input, (Layout, Tensor)): - raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") - if isinstance(input, _Tensor): - input = input.value - return _cute_ir.filter(input, loc=loc, ip=ip) + from .tensor import _Tensor + if not isinstance(input, (Layout, Tensor, ComposedLayout)): + raise TypeError(f"Expected layout or tensor as input, but got {type(input)=}") -@dsl_user_op -def product(a: Union[IntTuple, Shape], *, loc=None, ip=None): - """Return product of the given IntTuple or Shape. - - Computes the product of all elements in the input tuple or shape. - Returns static value if type is static. - - :param a: The input tuple or shape - :type a: IntTuple or Shape - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: Static product of IntTuple or Shape if static, otherwise a Value - :rtype: int or Value - :raises TypeError: If input is not an IntTuple or Shape - """ - if is_integer(a): - return a - if isinstance(a, tuple): - a_val = _pack_int_tuple(a, loc=loc, ip=ip) - res = _cute_ir.tuple_product(a_val, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) - else: - raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") - - -@overload -def product_like( - a: IntTuple, target_profile: XTuple, *, loc=None, ip=None -) -> IntTuple: ... -@overload -def product_like(a: Shape, target_profile: XTuple, *, loc=None, ip=None) -> Shape: ... - - -@dsl_user_op -def product_like( - a: Union[IntTuple, Shape], target_profile: XTuple, *, loc=None, ip=None -): - """Return product of the given IntTuple or Shape at leaves of `target_profile`. - - This function computes products according to the structure defined by target_profile. - - :param a: The input tuple or shape - :type a: IntTuple or Shape - :param target_profile: The profile that guides how products are computed - :type target_profile: XTuple - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: The resulting tuple with products computed according to target_profile - :rtype: IntTuple or Shape - :raises TypeError: If inputs have incompatible types - :raises ValueError: If inputs have incompatible shapes - """ - # Perform product at leaf of `target_profile` - if not isinstance(target_profile, tuple): - return product(a, loc=loc, ip=ip) - else: - if not isinstance(a, tuple): - raise TypeError(f"expects `a` tuple but got {a}") - - if len(a) != len(target_profile): - raise ValueError(f"expects `a` and `guide` have the same rank") - - return tuple( - product_like(x, g, loc=loc, ip=ip) for x, g in zip(a, target_profile) + if isinstance(input, ComposedLayout): + return make_composed_layout( + input.inner, input.offset, filter(input.outer, loc=loc, ip=ip) ) - - -@overload -def product_each(a: IntTuple, *, loc=None, ip=None) -> IntTuple: ... -@overload -def product_each(a: Shape, *, loc=None, ip=None) -> Shape: ... - - -@dsl_user_op -def product_each(a, *, loc=None, ip=None): - """Compute products for each component of the input. - - Returns a rank(a) tuple `result` such that get(result, mode=[i]) == product(get(a, mode=[i])) - - :param a: The input tuple or shape - :type a: IntTuple or Shape - :param loc: Source location for MLIR, defaults to None - :type loc: optional - :param ip: Insertion point, defaults to None - :type ip: optional - :return: A tuple containing products for each component - :rtype: tuple - :raises TypeError: If input is not an IntTuple or Shape - """ - if is_integer(a): - return a - if isinstance(a, tuple): - if not a: - return 1 - else: - a_val = _pack_int_tuple(a, loc=loc, ip=ip) - res = _cute_ir.tuple_product_each(a_val, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) + elif isinstance(input, _Tensor): + return _cute_ir.filter(input.value, loc=loc, ip=ip) else: - raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") + return _cute_ir.filter(input, loc=loc, ip=ip) @dsl_user_op @@ -3206,15 +2437,17 @@ def size( :rtype: int or Value :raises ValueError: If mode contains non-integer elements """ + from .atom import TiledCopy, TiledMma + if any(not isinstance(m, int) for m in mode): - raise ValueError(f"expects integer elements in mode, but got {mode}") + raise ValueError(f"Expected integer elements in mode, but got {mode}") if isinstance(a, (TiledMma, TiledCopy)): return a.size a_val = None if not isinstance(a, (Layout, ComposedLayout, Tensor)): a_val = _pack_int_tuple(a, loc=loc, ip=ip) - elif isinstance(a, Tensor): + elif isinstance(a, (ComposedLayout, Tensor)): a_val = a.value else: a_val = a @@ -3278,7 +2511,7 @@ def ceil_div(input: Shape, tiler: Tiler, *, loc=None, ip=None) -> Shape: result = cute.ceil_div(input, tiler) print(result) # Outputs: (4, 2) """ - input_val = _pack_shape(input, loc=loc, ip=ip) + input_val = _pack_int_tuple(input, loc=loc, ip=ip) tiler_val = _pack_tile(tiler, loc=loc, ip=ip) res = _cute_ir.ceil_div(input=input_val, tiler=tiler_val, loc=loc, ip=ip) return _unpack_x_tuple(res, loc=loc, ip=ip) @@ -3290,17 +2523,23 @@ def round_up(a: IntTuple, b: IntTuple) -> IntTuple: """ if isinstance(a, tuple): if not a: - raise ValueError(f"inputs cannot be empty") + raise ValueError("inputs cannot be empty") if not isinstance(b, tuple): raise TypeError( - f"expects both inputs to be tuple, but got {type(a)} and {type(b)}" + f"Expected both inputs to be tuple, but got {type(a)} and {type(b)}" ) if rank(a) < rank(b): raise ValueError( - f"expects rank(a) to be greater or equal than rank(b), but got {a}, {b}" + f"Expected rank(a) to be greater or equal than rank(b), but got {a}, {b}" ) b = append(b, 1, rank(a)) return tuple(round_up(x, y) for x, y in zip(a, b)) + + if isinstance(b, tuple): + raise TypeError( + f"Expected `b` to be a single integer when `a` is not a tuple, but got {b}" + ) + return ((a + b - 1) // b) * b @@ -3359,7 +2598,7 @@ def make_layout( * make_layout((3,4), stride=(1,4)) is more readable """ if stride is not None and not is_congruent(shape, stride): - raise ValueError(f"shape and stride must be congruent") + raise ValueError("shape and stride must be congruent") shape_val = _pack_shape(shape, loc=loc, ip=ip) if stride is not None: @@ -3453,6 +2692,64 @@ def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Lay ) +@dsl_user_op +def make_layout_like(input: Union[Layout, Tensor], *, loc=None, ip=None) -> Layout: + if isinstance(input, Tensor): + layout = input.layout + else: + layout = input + return _cute_ir.make_layout_like(layout, loc=loc, ip=ip) + + +class _ComposedLayoutWithInnerFunc(ComposedLayout): + @dsl_user_op + def __init__(self, inner, offset, outer, *, loc=None, ip=None): + self._inner = inner + self._offset = offset + self._outer = outer + + self._offset_val = _pack_int_tuple(offset, loc=loc, ip=ip) + + @dsl_user_op + def __call__(self, coord, *, loc=None, ip=None): + delta = self._outer(coord) + + delta_val = _pack_int_tuple(delta, loc=loc, ip=ip) + offset_val_new = _cute_ir.add_offset( + self._offset_val, delta_val, loc=loc, ip=ip + ) + offset_new = _unpack_x_tuple(offset_val_new, loc=loc, ip=ip) + + return self._inner(offset_new) + + def __str__(self): + return f"({self._inner} o {self._offset} o {self._outer})" + + @property + def type(self): + raise ValueError("type is not supported for customized composed layouts") + + @property + def is_normal(self): + return False + + @property + def inner(self, *, loc=None, ip=None): + return self._inner + + @property + def offset(self, *, loc=None, ip=None): + return self._offset + + @property + def outer(self, *, loc=None, ip=None): + return self._outer + + @property + def shape(self, *, loc=None, ip=None): + return self._outer.shape + + @dsl_user_op def make_composed_layout( inner, offset: IntTuple, outer: Layout, *, loc=None, ip=None @@ -3499,8 +2796,12 @@ def make_composed_layout( ) if isinstance(inner, Swizzle) and has_scaled_basis(outer.stride): raise TypeError(f"invalid composition {inner} o {offset} o {outer}") - offset_val = _pack_int_tuple(offset, loc=loc, ip=ip) - return _cute_ir.make_composed_layout(inner, offset_val, outer, loc=loc, ip=ip) + + if isinstance(inner, (Layout, Swizzle)): + offset_val = _pack_int_tuple(offset, loc=loc, ip=ip) + return _cute_ir.make_composed_layout(inner, offset_val, outer, loc=loc, ip=ip) + + return _ComposedLayoutWithInnerFunc(inner, offset, outer, loc=loc, ip=ip) @dsl_user_op @@ -3509,9 +2810,27 @@ def cosize( ): """Return size of codomain of layout or tensor. Return static value if type is static. + For a layout ``L = S:D`` where ``S`` is the shape and ``D`` is the stride, the codomain size is the + minimum size needed to store all possible offsets generated by the layout. This is calculated + by taking the maximum offset plus 1. + + For example, given a layout ``L = (4,(3,2)):(2,(8,1))``: + - Shape ``S = (4,(3,2))`` + - Stride ``D = (2,(8,1))`` + - Maximum offset = ``2*(4-1) + 8*(3-1) + 1*(2-1) = 6 + 16 + 1 = 23`` + - Therefore ``cosize(L) = 24`` + + **Examples:** + + .. code-block:: python + + L = cute.make_layout((4,(3,2)), stride=(2,(8,1))) # L = (4,(3,2)):(2,(8,1)) + print(cute.cosize(L)) # => 24 + :param a: Layout, ComposedLayout, or Tensor object :type a: Union[Layout, ComposedLayout, Tensor] - :param mode: List of mode(s) for cosize calculation + :param mode: List of mode(s) for cosize calculation. If empty, calculates over all modes. + If specified, calculates cosize only for the given modes. :type mode: List[int], optional :param loc: Location information for diagnostics, defaults to None :type loc: optional @@ -3520,20 +2839,27 @@ def cosize( :return: Static size of layout or tensor (fast fold) if static, or a dynamic Value :rtype: Union[int, Value] """ + from .tensor import _Tensor + if any(not is_static(m) for m in mode): raise ValueError(f"expects static mode, but got {mode}") - if isinstance(a, _Tensor): - a = a.value - res = _cute_ir.cosize(a, mode=mode, loc=loc, ip=ip) + if isinstance(a, (_Tensor, _ComposedLayout)): + res = _cute_ir.cosize(a.value, mode=mode, loc=loc, ip=ip) + else: + res = _cute_ir.cosize(a, mode=mode, loc=loc, ip=ip) return _unpack_x_tuple(res, loc=loc, ip=ip) @dsl_user_op def size_in_bytes( - dtype: Type[Numeric], layout: Union[Layout, ComposedLayout], *, loc=None, ip=None -): - """Calculate the size in bytes based on its data type and layout. + dtype: Type[Numeric], + layout: Union[Layout, ComposedLayout, None], + *, + loc=None, + ip=None, +) -> Union[int, Integer]: + """Calculate the size in bytes based on its data type and layout. The result is rounded up to the nearest byte. :param dtype: The DSL numeric data type :type dtype: Type[Numeric] @@ -3549,26 +2875,37 @@ def size_in_bytes( if not isinstance(dtype, NumericMeta): raise TypeError(f"dtype must be a Numeric, but got {dtype}") + size_in_elem = 0 if layout is None: - return 0 + size_in_elem = 0 + elif isinstance(layout, ComposedLayout): - if not isinstance(layout.inner, Swizzle): - raise TypeError( - f"invalid composed layout {layout}, inner must be a Swizzle" - ) + if isinstance(layout.inner, Swizzle): + # Swizzle layout is short-cut to cosize of outer + # User of swizzle must guarantee swizzle mapping with codomain + size_in_elem = cosize(layout.outer, loc=loc, ip=ip) + elif isinstance(layout.inner, Layout): + size_in_elem = cosize(layout.inner, loc=loc, ip=ip) else: - return cosize(layout.outer, loc=loc, ip=ip) * dtype.width // 8 + raise TypeError( + "Only support size when inner layout is Swizzle or normal Layout" + ) else: - return cosize(layout, loc=loc, ip=ip) * dtype.width // 8 + size_in_elem = cosize(layout, loc=loc, ip=ip) + + return ceil_div(size_in_elem * dtype.width, 8, loc=loc, ip=ip) # type: ignore @dsl_user_op def coalesce(input, *, target_profile: Coord = None, loc=None, ip=None): if target_profile: profile_val = _pack_coord(target_profile, loc=loc, ip=ip) - return _cute_ir.coalesce(input, target_profile=profile_val, loc=loc, ip=ip) else: - return _cute_ir.coalesce(input, loc=loc, ip=ip) + profile_val = None + + return _op_wrapper( + partial(_cute_ir.coalesce, target_profile=profile_val, loc=loc, ip=ip), input + ) @dsl_user_op @@ -3593,7 +2930,7 @@ def crd2idx(coord: Coord, layout, *, loc=None, ip=None): :returns: The result of applying the layout transformation to the provided coordinate. :rtype: Any type that the layout maps to - Example: + **Example:** .. code-block:: python @@ -3606,19 +2943,73 @@ def crd2idx(coord: Coord, layout, *, loc=None, ip=None): print(idx) foo() # Expected output: 11 """ - coord_val = _pack_coord(coord, loc=loc, ip=ip) + crd_val = _pack_coord(coord, loc=loc, ip=ip) if isinstance(layout, (tuple, int)): layout = make_layout(layout, loc=loc, ip=ip) + elif isinstance(layout, _ComposedLayout): + layout = layout.value - res = _cute_ir.crd2idx(coord_val, layout, loc=loc, ip=ip) + res = _cute_ir.crd2idx(crd_val, layout, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore + + +@overload +def idx2crd(idx: Int, shape: Int, *, loc=None, ip=None) -> Int: ... + + +@overload +def idx2crd(idx: IntTuple, shape: Tuple, *, loc=None, ip=None) -> Tuple: ... + + +@dsl_user_op +def idx2crd(idx, shape, *, loc=None, ip=None): + """ + Convert a linear index back into a multi-dimensional coordinate using the specified layout. + + Mapping from a linear index to the corresponding multi-dimensional coordinate in the layout's coordinate space. + It essentially "unfolds" a linear index into its constituent coordinate components. + + :param idx: The linear index to convert back to coordinates. + :type idx: : int/Integer/Tuple + :param shape: Shape of the layout defining the size of each mode + :type shape: Shape + :param loc: Optional location information for IR diagnostics. + :type loc: optional + :param ip: Optional instruction pointer or context for underlying IR functions. + :type ip: optional + :return: The result of applying the layout transformation to the provided coordinate. + :rtype: Coord + + **Examples:** + + .. code-block:: python + + import cutlass.cute as cute + @cute.jit + def foo(): + coord = cute.idx2crd(11, (5,4)) + # Computed as: 11 = 2 * 4 + 3, so coordinate is (2, 3) + print(coord) + foo() # Expected output: (2, 3) + + **Note:** + Python DSL is aligned with C++ DSL. + """ + if is_integer(idx) and is_integer(shape): + return idx + idx_val = _pack_int_tuple(idx, loc=loc, ip=ip) + shape_val = _pack_shape(shape, loc=loc, ip=ip) + res = _cute_ir.idx2crd(idx_val, shape_val, loc=loc, ip=ip) return _unpack_x_tuple(res, loc=loc, ip=ip) @dsl_user_op def recast_layout(new_type_bits, old_type_bits, src_layout, *, loc=None, ip=None): + if isinstance(src_layout, _ComposedLayout): + src_layout = src_layout.value return _cute_ir.recast_layout( new_type_bits, old_type_bits, src_layout, loc=loc, ip=ip - ) + ) # type: ignore @dsl_user_op @@ -3652,7 +3043,7 @@ def shape( :return: The shape of the input object, optionally filtered by mode :rtype: Shape - Example: + **Example:** .. code-block:: python @@ -3670,7 +3061,7 @@ def shape( if is_int_tuple(input): return get(input, mode=mode) - if isinstance(input, (Tensor, Layout)): + if isinstance(input, (Tensor, Layout, ComposedLayout)): shp = input.shape else: val = _cute_ir.get_shape(_pack_tile(input, loc=loc, ip=ip)) @@ -3692,9 +3083,13 @@ def recast_ptr( ip=None, ) -> Pointer: if dtype is not None: - if not isclass(dtype) or not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") - dtype = dtype.mlir_type + if isinstance(dtype, _cute_ir.SparseElemType): + # use SparseElemType as dtype + pass + else: + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") + dtype = dtype.mlir_type value_type = ptr.type.value_type if dtype is None else dtype swizzle = swizzle_.type.attribute if swizzle_ is not None else None @@ -3721,9 +3116,16 @@ def make_ptr( if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type): value = llvm.ptrtoint(T.i64(), value) + if not isinstance(mem_space, AddressSpace): + raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}") + + if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type): + value = llvm.ptrtoint(T.i64(), value) + if not is_integer(value): raise TypeError(f"expects integer value, but got {type(value)}") value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value) + value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value) bytes_per_elt = max(1, dtype.width // 8) if assumed_align is None: @@ -3735,292 +3137,15 @@ def make_ptr( ) aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width) - aligned_intptr = _cute_ir.assume(aligned_ty, value.ir_value(), loc=loc, ip=ip) + aligned_intptr = _cute_ir.assume( + aligned_ty, value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) data_ty = T.i8() if dtype is None else dtype.mlir_type ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align) return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip) -# -# Tensor API -# - - -@dsl_user_op -def make_tensor( - iterator, layout: Union[Shape, Layout, ComposedLayout], *, loc=None, ip=None -) -> Tensor: - """Creates a tensor by composing an engine (iterator/pointer) with a layout. - - A tensor is defined as T = E ∘ L, where E is an engine (array, pointer, or counting iterator) - and L is a layout that maps logical coordinates to physical offsets. The tensor - evaluates coordinates by applying the layout mapping and dereferencing the engine - at the resulting offset. - - :param iterator: Engine component (pointer, iterator, or counting iterator) that provides - data access capabilities - :type iterator: Union[Pointer, IntTuple] - :param layout: Layout component that defines the mapping from logical coordinates to - physical offsets - :type layout: Union[Shape, Layout, ComposedLayout] - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A tensor object representing the composition E ∘ L - :rtype: Tensor - - :raises ValueError: If iterator type is not supported - - **Examples:** - - .. code-block:: python - - # Create a tensor with row-major layout - layout = make_layout((64, 128), stride=(128, 1)) - tensor = make_tensor(ptr, layout) - - # Create a tensor with hierarchical layout - layout = make_layout(((128, 8), (1, 4, 1)), stride=((32, 1), (0, 8, 4096))) - tensor = make_tensor(smem_ptr, layout) - - # Create a coord tensor - layout = make_layout(2, stride=16 * E(0)) - tensor = make_tensor(5, layout) - - Notes: - - The engine (iterator) must support random access operations - - Common engine types include raw pointers, arrays, and random-access iterators - - The layout defines both the shape (logical dimensions) and stride (physical mapping) - - Supports both direct coordinate evaluation T(c) and partial evaluation (slicing) - """ - if not isinstance(layout, (Layout, ComposedLayout)): - layout = make_layout(layout, loc=loc, ip=ip) - elif isinstance(layout, ComposedLayout) and layout.type.is_normal_layout: - layout = layout.outer - - ty = None - if is_integer(iterator) or isinstance(iterator, tuple): - iterator = _pack_int_tuple(iterator, loc=loc, ip=ip) - ty = _cute_ir.CoordTensorType.get(iterator.type, layout.type) - elif isinstance(iterator, Pointer): - iterator = iterator.value - ty = _cute_ir.MemRefType.get(iterator.type, layout.type) - else: - raise TypeError(f"unsupported iterator type, got {type(iterator)}") - - return _cute_ir.make_view(result=ty, iter=iterator, layout=layout, loc=loc, ip=ip) - - -@dsl_user_op -def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor: - """Creates an identity tensor with the given shape. - - An identity tensor maps each coordinate to itself, effectively creating a counting - sequence within the shape's bounds. This is useful for generating coordinate indices - or creating reference tensors for layout transformations. - - :param shape: The shape defining the tensor's dimensions. Can be a simple integer - sequence or a hierarchical structure ((m,n),(p,q)) - :type shape: Shape - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: A tensor that maps each coordinate to itself - :rtype: Tensor - - **Examples:** - - .. code-block:: python - - # Create a simple 1D coord tensor - tensor = make_identity_tensor(6) # [0,1,2,3,4,5] - - # Create a 2D coord tensor - tensor = make_identity_tensor((3,2)) # [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)] - - # Create hierarchical coord tensor - tensor = make_identity_tensor(((2,1),3)) - # [((0,0),0),((1,0),0),((0,0),1),((1,0),1),((0,0),2),((1,0),2)] - - Notes: - - The shape parameter follows CuTe's IntTuple concept - - Coordinates are ordered colexicographically - - Useful for generating reference coordinates in layout transformations - """ - shape_val = _pack_shape(shape, loc=loc, ip=ip) - return _cute_ir.make_identity_tensor(shape_val, loc=loc, ip=ip) - - -@dsl_user_op -def make_fragment( - layout_or_shape: Union[Layout, Shape], - dtype: Type[Numeric], - *, - loc=None, - ip=None, -) -> Tensor: - if not issubclass(dtype, Numeric): - raise TypeError(f"value_type must be a type of Numeric, but got {type(dtype)}") - elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8() - - # Alignment for register memory is useless(?), pick-up large enough number - # to allow .128 (> 16B) load store - alignment = 32 - layout = None - if not isinstance(layout_or_shape, Layout): - layout = make_layout(layout_or_shape, loc=loc, ip=ip) - else: - layout = layout_or_shape - - ptr_ty = _cute_ir.PtrType.get(elem_ty, AddressSpace.rmem, alignment) - res_ty = _cute_ir.MemRefType.get(ptr_ty, layout.type) - tensor = _cute_ir.memref_alloca(res_ty, layout=layout, loc=loc, ip=ip) - return _Tensor(tensor.value, dtype) - - -@overload -def make_fragment_like( - src: Tensor, dtype: Optional[Type[Numeric]], *, loc=None, ip=None -) -> Tensor: ... - - -@overload -def make_fragment_like(src: Layout, *, loc=None, ip=None) -> Layout: ... - - -@overload -def make_fragment_like(src: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ... - - -@dsl_user_op -def make_fragment_like(src, dtype=None, *, loc=None, ip=None): - """Create tensor with a compact layout in the same shape as the source on stack. - - This function either creates a fragment tensor with compact layout in - same shape as the source layout or a new layout with the same shape as the source. - The strides of the new layout follow the order induced by the source's strides, with a - special handling of the 0th mode: it is always stride-1 and generated in column-major order - (LayoutLeft). - - :param src: The source layout or tensor whose shape will be matched - :type src: Union[Layout, ComposedLayout, Tensor] - :param dtype: The element type for the fragment tensor, defaults to None - :type dtype: Type[Numeric], optional - :param loc: Source location for MLIR operations, defaults to None - :type loc: Location, optional - :param ip: Insertion point for MLIR operations, defaults to None - :type ip: InsertionPoint, optional - - :return: A new layout or fragment tensor with matching shape - :rtype: Union[Layout, Tensor] - - **Examples:** - - Creating a rmem tensor from a tensor: - - .. code-block:: python - - smem_tensor = cute.make_tensor(smem_ptr, layout) - frag_tensor = cute.make_fragment_like(smem_tensor, cutlass.Float32) - # frag_tensor will be a register-backed tensor with the same shape - - Creating a fragment with a different element type: - - .. code-block:: python - - tensor = cute.make_tensor(gmem_ptr, layout) - bool_frag = cute.make_fragment_like(tensor, cutlass.Boolean) - # bool_frag will be a register-backed tensor with Boolean elements - - **Notes** - - - When used with a Tensor, if a type is provided, it will create a new - fragment tensor with that element type. - - For layouts with ScaledBasis strides, the function creates a fragment - from the shape only. - - This function is commonly used in GEMM and other tensor operations to - create register storage for intermediate results. - - """ - if isinstance(src, (Layout, ComposedLayout)): - new_layout = None - # Create base fragment layout - if isinstance(src, Layout) and has_scaled_basis(src.stride): - # For scaled basis strides, create fragment from shape only - new_layout = _cute_ir.make_fragment_like( - make_layout(src.shape), loc=loc, ip=ip - ) - else: - # Otherwise use full source layout - new_layout = _cute_ir.make_fragment_like(src, loc=loc, ip=ip) - if dtype is not None: - # call make_fragment to convert layout to tensor - return make_fragment(new_layout, dtype, loc=loc, ip=ip) - else: - return new_layout - elif isinstance(src, Tensor): - if isinstance(src.type, _cute_ir.CoordTensorType): - if dtype is None: - raise ValueError( - "dtype must be provided when src is a coordinate tensor" - ) - - new_layout = _cute_ir.make_fragment_like( - make_layout(src.shape), loc=loc, ip=ip - ) - return make_fragment(new_layout, dtype, loc=loc, ip=ip) - else: - dtype = src.element_type if dtype is None else dtype - ty = dtype.mlir_type if dtype is not Boolean else T.i8() - new_tensor = _cute_ir.make_fragment_like( - src.value, elem_type=ty, loc=loc, ip=ip - ) - return _Tensor(new_tensor.value, dtype) - else: - raise TypeError( - f"src must be a Layout or ComposedLayout or tensor, got {type(src)}" - ) - - -@dsl_user_op -def recast_tensor( - src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None -): - if not isclass(dtype) or not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") - - if dtype is Boolean: - dst_width = 8 - else: - dst_width = dtype.width - - if src.element_type is Boolean: - src_width = 8 - else: - src_width = src.element_type.width - - src_iter = recast_ptr(src.iterator, dtype=dtype, loc=loc, ip=ip) - src_layout = recast_layout(dst_width, src_width, src.layout, loc=loc, ip=ip) - return make_tensor(src_iter, src_layout, loc=loc, ip=ip) - - -@dsl_user_op -def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: - offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip) - if isinstance(tensor.iterator, Pointer): - return make_tensor(tensor.iterator + offset, tensor.layout) - elif is_integer(tensor.iterator) or isinstance(tensor.iterator, tuple): - new_iter = _cute_ir.add_offset( - _pack_int_tuple(tensor.iterator), _pack_int_tuple(offset) - ) - return make_tensor(_unpack_x_tuple(new_iter), tensor.layout) - else: - raise ValueError(f"unsupported tensor for domain_offset, got {tensor}") - - # # Layout algebra # @@ -4030,8 +3155,10 @@ def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: def composition( lhs: Layout, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None ) -> Layout: ... - - +@overload +def composition( + lhs: ComposedLayout, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None +) -> ComposedLayout: ... @overload def composition( lhs: Tensor, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None @@ -4068,7 +3195,7 @@ def composition(lhs, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None): R(c) = lhs(rhs(c)). :rtype: Layout or Tensor - Example: + **Example:** .. code-block:: python @@ -4083,10 +3210,12 @@ def composition(lhs, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None): L3 = cute.composition(L1, L2) # L3 now maps coordinates through L2 then L1 """ + from .tensor import _Tensor + rhs_val = rhs if not isinstance(rhs, Layout) and isinstance(rhs, (int, tuple)): rhs_val = _pack_tile(rhs, loc=loc, ip=ip) - if isinstance(lhs, _Tensor): + if isinstance(lhs, (_Tensor, _ComposedLayout)): lhs = lhs.value return _cute_ir.composition(lhs, rhs_val, loc=loc, ip=ip) @@ -4116,7 +3245,7 @@ def complement( :returns: The complement layout :rtype: Layout - Example: + **Example:** .. code-block:: python @@ -4139,28 +3268,52 @@ def complement( @dsl_user_op def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout: if not isinstance(input, Layout): - raise TypeError(f"expects input of type Layout, but got {type(input)}") + raise TypeError(f"Expected input of type Layout, but got {type(input)}") + return _cute_ir.right_inverse(input=input, loc=loc, ip=ip) @dsl_user_op def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout: if not isinstance(input, Layout): - raise TypeError(f"expects input of type Layout, but got {type(input)}") + raise TypeError(f"Expected input of type Layout, but got {type(input)}") + return _cute_ir.left_inverse(input=input, loc=loc, ip=ip) @overload -def logical_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +def logical_product(block: Layout, tiler: Tile, *, loc=None, ip=None) -> Layout: ... @overload def logical_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None + block: ComposedLayout, tiler: Tile, *, loc=None, ip=None ) -> ComposedLayout: ... @dsl_user_op -def logical_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.logical_product(input=block, tiler=tiler, loc=loc, ip=ip) +def logical_product(block, tiler: Tile, *, loc=None, ip=None): + if isinstance(block, _ComposedLayout): + block = block.value + + tiler_val = tiler + if isinstance(tiler, Layout): + return _cute_ir.logical_product(input=block, tiler=tiler_val, loc=loc, ip=ip) + if tiler is None: + return block + if is_integer(tiler): + return _cute_ir.logical_product( + input=block, tiler=make_layout(tiler_val), loc=loc, ip=ip + ) + assert rank(tiler_val) <= rank(block), "logical_product: Too many modes in tiler." + tiler_rank = rank(tiler_val) + block_rank = rank(block) + res = tuple( + logical_product(block[i], tiler_val[i]) if i < tiler_rank else block[i] + for i in range(block_rank) + ) + + res_shape = tuple(res[i].shape for i in range(block_rank)) + res_stride = tuple(res[i].stride for i in range(block_rank)) + return make_layout(res_shape, stride=res_stride, loc=loc, ip=ip) @overload @@ -4173,7 +3326,10 @@ def zipped_product( @dsl_user_op def zipped_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.zipped_product(input=block, tiler=tiler, loc=loc, ip=ip) + if isinstance(block, _ComposedLayout): + return _cute_ir.zipped_product(input=block.value, tiler=tiler, loc=loc, ip=ip) + else: + return _cute_ir.zipped_product(input=block, tiler=tiler, loc=loc, ip=ip) @overload @@ -4186,7 +3342,10 @@ def tiled_product( @dsl_user_op def tiled_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.tiled_product(input=block, tiler=tiler, loc=loc, ip=ip) + if isinstance(block, _ComposedLayout): + return _cute_ir.tiled_product(input=block.value, tiler=tiler, loc=loc, ip=ip) + else: + return _cute_ir.tiled_product(input=block, tiler=tiler, loc=loc, ip=ip) @overload @@ -4199,7 +3358,10 @@ def flat_product( @dsl_user_op def flat_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.flat_product(input=block, tiler=tiler, loc=loc, ip=ip) + if isinstance(block, _ComposedLayout): + return _cute_ir.flat_product(input=block.value, tiler=tiler, loc=loc, ip=ip) + else: + return _cute_ir.flat_product(input=block, tiler=tiler, loc=loc, ip=ip) @overload @@ -4212,7 +3374,10 @@ def raked_product( @dsl_user_op def raked_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.raked_product(input=block, tiler=tiler, loc=loc, ip=ip) + if isinstance(block, _ComposedLayout): + return _cute_ir.raked_product(input=block.value, tiler=tiler, loc=loc, ip=ip) + else: + return _cute_ir.raked_product(input=block, tiler=tiler, loc=loc, ip=ip) @overload @@ -4225,7 +3390,10 @@ def blocked_product( @dsl_user_op def blocked_product(block, tiler: Layout, *, loc=None, ip=None): - return _cute_ir.blocked_product(input=block, tiler=tiler, loc=loc, ip=ip) + if isinstance(block, _ComposedLayout): + return _cute_ir.blocked_product(input=block.value, tiler=tiler, loc=loc, ip=ip) + else: + return _cute_ir.blocked_product(input=block, tiler=tiler, loc=loc, ip=ip) @overload @@ -4236,14 +3404,11 @@ def logical_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor @dsl_user_op def logical_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore + return _op_wrapper( + partial(_cute_ir.logical_divide, tiler=tiler, loc=loc, ip=ip), target + ) @overload @@ -4254,14 +3419,11 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: @dsl_user_op def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore + return _op_wrapper( + partial(_cute_ir.zipped_divide, tiler=tiler, loc=loc, ip=ip), target + ) @overload @@ -4272,32 +3434,26 @@ def tiled_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: @dsl_user_op def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + return _op_wrapper( + partial(_cute_ir.tiled_divide, tiler=tiler, loc=loc, ip=ip), target + ) @overload -def flat_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +def flat_divide(target: Layout, tiler: Tile, *, loc=None, ip=None) -> Layout: ... @overload -def flat_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... +def flat_divide(target: Tensor, tiler: Tile, *, loc=None, ip=None) -> Tensor: ... @dsl_user_op -def flat_divide(target, tiler: Tiler, *, loc=None, ip=None): - res_type = None - if isinstance(target, _Tensor): - res_type = target.element_type - target = target.value +def flat_divide(target, tiler: Tile, *, loc=None, ip=None): if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) - res = _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip) - return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + return _op_wrapper( + partial(_cute_ir.flat_divide, tiler=tiler, loc=loc, ip=ip), target + ) # @@ -4309,6 +3465,8 @@ def flat_divide(target, tiler: Tiler, *, loc=None, ip=None): def max_common_layout( a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None ) -> Layout: + from .tensor import _Tensor + a_layout = a.layout if isinstance(a, _Tensor) else a b_layout = b.layout if isinstance(b, _Tensor) else b @@ -4330,6 +3488,8 @@ def max_common_layout( def max_common_vector( a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None ) -> int: + from .tensor import _Tensor + a_layout = a.layout if isinstance(a, _Tensor) else a b_layout = b.layout if isinstance(b, _Tensor) else b @@ -4347,18 +3507,25 @@ def max_common_vector( return 1 -@dsl_user_op +@overload def tile_to_shape( - atom: Union[Layout, ComposedLayout], - trg_shape: Shape, - order: Shape, - *, - loc=None, - ip=None, -) -> Union[Layout, ComposedLayout]: - trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip) - order = _pack_int_tuple(order, loc=loc, ip=ip) - return _cute_ir.tile_to_shape(atom, trg_shape, order, loc=loc, ip=ip) + atom: Layout, trg_shape: Shape, order: Shape, *, loc=None, ip=None +) -> Layout: ... +@overload +def tile_to_shape( + atom: ComposedLayout, trg_shape: Shape, order: Shape, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def tile_to_shape(atom, trg_shape: Shape, order: Shape, *, loc=None, ip=None): + trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip) # type: ignore + order = _pack_int_tuple(order, loc=loc, ip=ip) # type: ignore + + if isinstance(atom, _ComposedLayout): + return _cute_ir.tile_to_shape(atom.value, trg_shape, order, loc=loc, ip=ip) + else: + return _cute_ir.tile_to_shape(atom, trg_shape, order, loc=loc, ip=ip) @dsl_user_op @@ -4374,7 +3541,7 @@ def local_partition( if isinstance(index, cutlass_arith.ArithValue): index_val = index else: - index_val = index.ir_value() + index_val = index.ir_value(loc=loc, ip=ip) if index_val.type.width > 32: raise NotImplementedError( f"Index value should be 32-bit or smaller integer type, but got {index_val.type}" @@ -4387,18 +3554,18 @@ def local_partition( @dsl_user_op def local_tile( input: Tensor, - tiler: Union[Layout, Shape], + tiler: Tiler, coord: Coord, - proj: XTuple = None, + proj: XTuple = None, # type: ignore *, loc=None, ip=None, ) -> Tensor: - tiler_val = _pack_shape(tiler, loc=loc, ip=ip) + tiler_val = _pack_tile(tiler, loc=loc, ip=ip) coord_val = _pack_coord(coord, loc=loc, ip=ip) if proj is not None: if not isinstance(proj, tuple): - raise TypeError(f"Expects tuple for proj, but got {type(proj)}") + raise TypeError(f"Expected tuple for proj, but got {type(proj)}") proj_val = _pack_coord(proj, loc=loc, ip=ip) proj = proj_val.type.attribute @@ -4445,662 +3612,63 @@ def make_layout_image_mask( # Given that we replace only one mode with _, the rank of the slice should be 1 assert rank(sliced_lay) == 1 + if not is_static(sliced_lay): + raise ValueError("make_layout_image_mask requires the layout to be static") + # Create the mask of the image mcast_mask = Int16(0) - for i in range(size(sliced_lay)): + for i in range(size(sliced_lay)): # type: ignore mcast_mask = mcast_mask | (1 << sliced_lay(i)) mcast_mask <<= offset return Int16(mcast_mask) -#################################################################################################### -# -# Atom -# -#################################################################################################### - - -class Op(ABC): +def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: # type: ignore """ - Operation abstract base class. + Find the leading dimension of a shape and stride. + + :param shape: The shape of the tensor or layout + :type shape: Shape + :param stride: The stride of the tensor or layout + :type stride: Stride + :return: The leading dimension index or indices + :rtype: Union[int, Tuple[int, ...], None] + + The return value depends on the stride pattern: + + * If a single leading dimension is found, returns an integer index + * If nested leading dimensions are found, returns a tuple of indices + * If no leading dimension is found, returns None """ - pass - - -class MmaOp(Op): - """ - MMA Operation abstract base class. - """ - - @abstractmethod - def _make_trait(self, *, loc=None, ip=None, **kwargs): - pass - - -class CopyOp(Op): - """ - Copy Operation abstract base class. - """ - - @abstractmethod - def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ): - pass - - -class Trait(ABC): - """ - Trait abstract base class. - - Traits are internal-only classes used by Atoms that wrap the underlying IR Value. The Python - user should only interact with Ops and Atoms. - """ - - def __init__(self, value: ir.Value) -> None: - self.value = value - - def __extract_mlir_values__(self): - return [self.value] - - def __new_from_mlir_values__(self, values): - return self.__class__(values[0]) - - def set(self, field, value, *, loc=None, ip=None) -> None: - raise NotImplementedError( - "set not implemented, the requesting Atom has likely no runtime state" - ) - - def unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: - return self.value - - -class Atom(ABC): - """ - Atom base class. - - An Atom is the composition of - - - a MMA or Copy Operation; - - an internal MMA or Copy Trait. - - An Operation is a pure Python class that is used to model a specific MMA or Copy instruction. - The Trait wraps the underlying IR Value and provides access to the metadata of the instruction - encoded using CuTe Layouts. When the Trait can be constructed straighforwardly from an - Operation, the ``make_mma_atom`` or ``make_copy_atom`` API should be used. There are cases where - constructing the metadata is not trivial and requires more information, for example to determine - the number of bytes copied per TMA instruction ("the TMA vector length"). In such cases, - dedicated helper functions are provided with an appropriate API such that the Atom is - constructed internally in an optimal fashion for the user. - """ - - def __init__(self, op: Op, trait: Trait) -> None: - self._op = op - self._trait = trait - - def __extract_mlir_values__(self): - return extract_mlir_values(self._trait) - - def __new_from_mlir_values__(self, values): - return self.__class__(self.op, new_from_mlir_values(self._trait, values)) - - @property - def op(self) -> Op: - return self._op - - @property - def type(self): - return self._trait.value.type - - @dsl_user_op - def set(self, modifier, value, *, loc=None, ip=None) -> None: - """ - Sets runtime fields of the Atom. - - Some Atoms have runtime state, for example a tcgen05 MMA Atom - - - .. code-block:: python - - tiled_mma = cute.make_tiled_mma(some_tcgen05_mma_op) - tiled_mma.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True) - - The ``set`` method provides a way to the user to modify such runtime state. Modifiable - fields are provided by arch-specific enumerations, for example ``tcgen05.Field``. The Atom - instance internally validates the field as well as the value provided by the user to set - the field to. - """ - self._trait.set(modifier, value, loc=loc, ip=ip) - - def _unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: - return self._trait.unpack(loc=loc, ip=ip, **kwargs) - - -#################################################################################################### -# -# MMA Atoms, TiledMma, and ThrMma -# -#################################################################################################### - - -class MmaAtom(Atom): - """ - The MMA Atom class. - """ - - def __str__(self) -> str: - res = "MMA Atom\n" - res += " ThrID: " + pretty_str(self.thr_id) + "\n" - res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" - res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" - res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" - res += " TV Layout C: " + pretty_str(self.tv_layout_C) - return res - - # - # Properties - # - - @property - def thr_id(self) -> Layout: - return _cute_ir.static(self._trait.value.type.thr_id) - - @property - def shape_mnk(self) -> Shape: - return _unpack_x_tuple(self._trait.value.type.shape_mnk) - - @property - def tv_layout_A(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_a_tv) - - @property - def tv_layout_B(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_b_tv) - - @property - def tv_layout_C(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_c_tv) - - # - # make_fragment - # - - @dsl_user_op - def make_fragment_A(self, input, *, loc=None, ip=None): - # input could be memref/shape/layout for tmem based fragment - if isinstance(input, _Tensor): - if self.op is not None: - self.op._verify_fragment_A(input, loc=loc, ip=ip) - input = input.value - if isinstance(input, tuple): - input = _pack_shape(input, loc=loc, ip=ip) - return _cute_ir.mma_make_fragment( - _cute_ir.MmaOperand.A, - self._trait.value, - input, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def make_fragment_B(self, input, *, loc=None, ip=None): - if isinstance(input, _Tensor): - if self.op is not None: - self.op._verify_fragment_B(input, loc=loc, ip=ip) - input = input.value - return _cute_ir.mma_make_fragment( - _cute_ir.MmaOperand.B, - self._trait.value, - input, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def make_fragment_C(self, input, *, loc=None, ip=None): - # input could be memref/shape/layout for tmem based fragment - if isinstance(input, _Tensor): - input = input.value - if isinstance(input, tuple): - input = _pack_shape(input, loc=loc, ip=ip) - return _cute_ir.mma_make_fragment( - _cute_ir.MmaOperand.C, - self._trait.value, - input, - loc=loc, - ip=ip, - ) - - -class TiledMma(MmaAtom): - """ - The tiled MMA class. - """ - - def __str__(self) -> str: - res = "Tiled MMA\n" - res += " Thr Layout VMNK: " + pretty_str(self.thr_layout_vmnk) + "\n" - res += " Permutation MNK: " + pretty_str(self.permutation_mnk) + "\n" - res += "MMA Atom\n" - res += " ThrID: " + pretty_str(self.thr_id) + "\n" - res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n" - res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n" - res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n" - res += " TV Layout C: " + pretty_str(self.tv_layout_C) - return res - - # - # Properties - # - - @property - def tv_layout_A_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_a_tv_tiled) - - @property - def tv_layout_B_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_b_tv_tiled) - - @property - def tv_layout_C_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_c_tv_tiled) - - @property - def permutation_mnk(self) -> Tile: - return _unpack_x_tuple(self._trait.value.type.permutation_mnk) - - @property - def thr_layout_vmnk(self) -> Layout: - return _cute_ir.static(self._trait.value.type.thr_layout_vmnk) - - @property - def size(self) -> int: - return self._trait.value.type.size - - # - # Tiler - # - - def get_tile_size(self, mode_idx: int) -> Shape: - assert (mode_idx >= 0) and (mode_idx < 3) - perm_tile = self.permutation_mnk[mode_idx] - if perm_tile is None: - thr_layout_vmnk = self.thr_layout_vmnk - atom_shape_mnk = self.shape_mnk - return size(atom_shape_mnk, mode=[mode_idx]) * size( - thr_layout_vmnk, mode=[mode_idx + 1] - ) - else: - return size(perm_tile) - - # - # get_slice - # - - def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrMma": - return ThrMma(self.op, self._trait, thr_idx) - - # - # partition_shape - # - - def _partition_shape(self, operand_id, shape, *, loc=None, ip=None): - shape = _pack_shape(shape, loc=loc, ip=ip) - return _unpack_x_tuple( - _cute_ir.tiled_mma_partition_shape( - operand_id, self._trait.value, shape, loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - - @dsl_user_op - def partition_shape_A(self, shape_mk, *, loc=None, ip=None): - return self._partition_shape(_cute_ir.MmaOperand.A, shape_mk, loc=loc, ip=ip) - - @dsl_user_op - def partition_shape_B(self, shape_nk, *, loc=None, ip=None): - return self._partition_shape(_cute_ir.MmaOperand.B, shape_nk, loc=loc, ip=ip) - - @dsl_user_op - def partition_shape_C(self, shape_mn, *, loc=None, ip=None): - return self._partition_shape(_cute_ir.MmaOperand.C, shape_mn, loc=loc, ip=ip) - - # - # _thrfrg - # - - @overload - def _thrfrg(self, operand_id, input: Layout, *, loc=None, ip=None) -> Layout: ... - - @overload - def _thrfrg(self, operand_id, input: Tensor, *, loc=None, ip=None) -> Tensor: ... - - def _thrfrg(self, operand_id, input, *, loc=None, ip=None) -> Union[Tensor, Layout]: - if isinstance(input, Tensor): - return make_tensor( - input.iterator, - self._thrfrg(operand_id, input.layout, loc=loc, ip=ip), - ) - elif isinstance(input, Layout): - if not is_static(input.type): - raise ValueError(f"Expects a static layout but got {input.type}") - return _cute_ir.static( - self._trait.value.type.thrfrg(operand_id, input), loc=loc, ip=ip - ) - - raise ValueError( - f"Expects a layout or a tensor as input but got {type(input)=}" - ) - - def _thrfrg_A( - self, input: Union[Layout, Tensor], *, loc=None, ip=None - ) -> Union[Layout, Tensor]: - return self._thrfrg(_cute_ir.MmaOperand.A, input, loc=loc, ip=ip) - - def _thrfrg_B( - self, input: Union[Layout, Tensor], *, loc=None, ip=None - ) -> Union[Layout, Tensor]: - return self._thrfrg(_cute_ir.MmaOperand.B, input, loc=loc, ip=ip) - - def _thrfrg_C( - self, input: Union[Layout, Tensor], *, loc=None, ip=None - ) -> Union[Layout, Tensor]: - return self._thrfrg(_cute_ir.MmaOperand.C, input, loc=loc, ip=ip) - - -class ThrMma(TiledMma): - """ - The thread MMA class for modeling a thread-slice of a tiled MMA. - """ - - def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: - super().__init__(op, trait) - self._thr_idx = thr_idx - - def __new_from_mlir_values__(self, values): - return self.__class__( - self.op, new_from_mlir_values(self._trait, values), self.thr_idx - ) - - @property - def thr_idx(self): - return self._thr_idx - - @dsl_user_op - def partition_A(self, input_mk: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_mma_partition( - _cute_ir.MmaOperand.A, - self._trait.value, - input_mk.value, - thr_idx, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def partition_B(self, input_nk: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_mma_partition( - _cute_ir.MmaOperand.B, - self._trait.value, - input_nk.value, - thr_idx, - loc=loc, - ip=ip, - ) - - @dsl_user_op - def partition_C(self, input_mn: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_mma_partition( - _cute_ir.MmaOperand.C, - self._trait.value, - input_mn.value, - thr_idx, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_mma_atom(op: MmaOp, *, loc=None, ip=None, **kwargs) -> MmaAtom: - """ - Makes an MMA Atom from an MMA Operation. - - This function creates an MMA Atom from a given MMA Operation. Arbitrary kw arguments can be - provided for Op-specific additional parameters. They are not used as of today. - - :param op: The MMA Operation to construct an Atom for - :type op: MmaOp - :return: The MMA Atom - :rtype: MmaAtom - """ - trait = op._make_trait(loc=loc, ip=ip, **kwargs) - return MmaAtom(op, trait) - - -@dsl_user_op -def make_tiled_mma( - op_or_atom: Union[Op, MmaAtom], - atom_layout_mnk=(1, 1, 1), - permutation_mnk=None, - *, - loc=None, - ip=None, - **kwargs, -) -> TiledMma: - """ - Makes a tiled MMA from an MMA Operation or an MMA Atom. - - :param op_or_atom: The MMA Operation or Atom - :type op_or_atom: Union[Op, MmaAtom] - :param atom_layout_mnk: A Layout describing the tiling of Atom across threads - :type atom_layout_mnk: Layout - :param permutation_mnk: A permutation Tiler describing the tiling of Atom across values including any permutation of such tiling - :type permutation_mnk: Tiler - :return: The resulting tiled MMA - :rtype: TiledMma - """ - if isinstance(op_or_atom, Op): - op = op_or_atom - atom = make_mma_atom(op_or_atom, loc=loc, ip=ip, **kwargs) - elif isinstance(op_or_atom, MmaAtom): - op = op_or_atom.op - atom = op_or_atom - else: - raise TypeError( - f"expected an MMA Op or Atom, but got an instance of {type(op_or_atom)}" - ) - if isinstance(atom_layout_mnk, tuple): - atom_layout_mnk = make_layout(atom_layout_mnk, loc=loc, ip=ip) - if rank(atom_layout_mnk) != 3: - raise ValueError(f"expects rank-3 MNK atom layout, but got {atom_layout_mnk}") - permutation_mnk_ty = None - if permutation_mnk is not None: - permutation_mnk_ty = _pack_tile(permutation_mnk, loc=loc, ip=ip).type - ty = _cute_nvgpu_ir.TiledMmaType.get( - atom._trait.value.type, - atom_layout_mnk.type, - permutation_mnk_ty, - ) - val = _cute_ir.make_tiled_mma(ty, atom._trait.value, loc=loc, ip=ip) - # Instead of modifying atom which might have been provided by the user, create a brand new - # trait instance and replace the Atom ir.Value with the tiled one - trait = new_from_mlir_values(atom._trait, [val]) - return TiledMma(op, trait) - - -#################################################################################################### -# -# Copy Atoms, TiledCopy, and ThrCopy -# -#################################################################################################### - - -class CopyAtom(Atom): - """ - The Copy Atom class. - """ - - def __str__(self) -> str: - res = "Copy Atom\n" - res += " ThrID: " + str(self.thr_id) + "\n" - res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" - res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" - res += " Value type: " + str(self._trait.value.type.value_type) - return res - - # - # Properties - # - - @property - def value_type(self) -> Type[Numeric]: - return Numeric.from_mlir_type(self._trait.value.type.value_type) - - @property - def thr_id(self) -> Layout: - return _cute_ir.static(self._trait.value.type.thr_id) - - @property - def layout_src_tv(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_src_tv) - - @property - def layout_dst_tv(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_dst_tv) - - -class TiledCopy(CopyAtom): - """ - The tiled Copy class. - """ - - def __str__(self) -> str: - res = "Tiled Copy\n" - res += " Tiler MN: " + pretty_str(self.tiler_mn) + "\n" - res += " TV Layout tiled: " + str(self.layout_tv_tiled) + "\n" - res += "Copy Atom\n" - res += " ThrID: " + str(self.thr_id) + "\n" - res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" - res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" - res += " Value type: " + str(self._trait.value.type.value_type) - return res - - # - # Properties - # - - @property - def layout_tv_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_tv_tiled) - - @property - def tiler_mn(self) -> Tile: - return _unpack_x_tuple(self._trait.value.type.tiler_mn) - - @property - def layout_src_tv_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_src_tv_tiled) - - @property - def layout_dst_tv_tiled(self) -> Layout: - return _cute_ir.static(self._trait.value.type.layout_dst_tv_tiled) - - @property - def size(self) -> int: - return self._trait.value.type.size - - # - # get_slice and retile - # - - def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrCopy": - return ThrCopy(self.op, self._trait, thr_idx) - - @dsl_user_op - def retile(self, src, *, loc=None, ip=None): - return _cute_ir.tiled_copy_retile( - tiled_copy=self._trait.value, input=src.value, loc=loc, ip=ip - ) - - -class ThrCopy(TiledCopy): - """ - The thread Copy class for modeling a thread-slice of a tiled Copy. - """ - - def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: - super().__init__(op, trait) - self._thr_idx = thr_idx - - def __new_from_mlir_values__(self, values): - return self.__class__( - self.op, new_from_mlir_values(self._trait, values), self.thr_idx - ) - - @property - def thr_idx(self): - return self._thr_idx - - @dsl_user_op - def partition_S(self, src: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_copy_partition_S( - self._trait.value, src.value, thr_idx, loc=loc, ip=ip - ) - - @dsl_user_op - def partition_D(self, dst: Tensor, *, loc=None, ip=None) -> Tensor: - thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) - return _cute_ir.tiled_copy_partition_D( - self._trait.value, dst.value, thr_idx, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_copy_atom( - op: CopyOp, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs -) -> CopyAtom: - """ - Makes a Copy Atom from a Copy Operation. - - This function creates a Copy Atom from a given Copy Operation. Arbitrary kw arguments can be - provided for Op-specific additional parameters. - - Example: - - .. code-block:: python - - op = cute.nvgpu.CopyUniversalOp() - atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) - - :param op: The Copy Operation to construct an Atom for - :type op: CopyOp - :param copy_internal_type: An internal data type used to construct the source/destination layouts in unit of tensor elements - :type copy_internal_type: Type[Numeric] - :return: The Copy Atom - :rtype: CopyAtom - """ - trait = op._make_trait(copy_internal_type, loc=loc, ip=ip, **kwargs) - return CopyAtom(op, trait) + def pred_fn(val, pos): + # skip dynamic values which can't be compared + # find the candidate target val, stride at this position is 1 + if (not is_dynamic_expression(val)) and (val == 1): + # extract the shape at this position + mode = [pos] if isinstance(pos, int) else list(pos) + s = get(shape, mode) + if is_dynamic_expression(s) or s != 1: + # shape at this position is dynamic value or not 1 + # we found the leading dimension + return True + return False + + return find_if(stride, pred_fn=pred_fn) @dsl_user_op def make_layout_tv( thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None ) -> Tuple[Shape, Layout]: - """Create a thread-value layout for partitioning data tensors. + """Create a thread-value layout by repeating the val_layout over the thr_layout. This function creates a thread-value layout that maps between ``(thread_idx, value_idx)`` - coordinates and logical ``(M,N)`` coordinates. The thread layout must be compact to ensure + coordinates and logical ``(M,N)`` coordinates. The thread and value layouts must be compact to ensure proper partitioning. - This implements the thread-value partitioning pattern shown in - Figure TVLayout, where data is partitioned across threads and values within each thread. + This implements the thread-value partitioning pattern where data is partitioned + across threads and values within each thread. :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) :type thr_layout: Layout @@ -5110,7 +3678,6 @@ def make_layout_tv( :type loc: Optional[Location], optional :param ip: Insertion point, defaults to None :type ip: Optional[InsertionPoint], optional - :return: A tuple containing ``tiler_mn`` and ``layout_tv`` :rtype: Tuple[Shape, Layout] @@ -5120,17 +3687,35 @@ def make_layout_tv( **Example:** + The below code creates a TV Layout that maps thread/value coordinates to the logical coordinates in a ``(4,6)`` tensor: + - *Tiler MN*: ``(4,6)`` + - *TV Layout*: ``((3,2),(2,2)):((8,2),(4,1))`` + .. code-block:: python - tiler_mn, layout_tv = cute.make_layout_tv( - cute.make_layout((4, 8), stride=(8, 1)), cute.make_layout(2, stride=1) - ) + thr_layout = cute.make_layout((2, 3), stride=(3, 1)) + val_layout = cute.make_layout((2, 2), stride=(2, 1)) + tiler_mn, layout_tv = cute.make_layout_tv(thr_layout, val_layout) - Above code creates a TV layout that maps between thread/value coordinates - and the logical coordinates in a 8x8 matrix with: + .. table:: TV Layout + :widths: auto + + +---+-----+-----+-----+-----+-----+-----+ + | | 0 | 1 | 2 | 3 | 4 | 5 | + +---+-----+-----+-----+-----+-----+-----+ + | 0| T0, | T0, | T1, | T1, | T2, | T2, | + | | V0 | V1 | V0 | V1 | V0 | V1 | + +---+-----+-----+-----+-----+-----+-----+ + | 1| T0, | T0, | T1, | T1, | T2, | T2, | + | | V2 | V3 | V2 | V3 | V2 | V3 | + +---+-----+-----+-----+-----+-----+-----+ + | 2| T3, | T3, | T4, | T4, | T5, | T5, | + | | V0 | V1 | V0 | V1 | V0 | V1 | + +---+-----+-----+-----+-----+-----+-----+ + | 3| T3, | T3, | T4, | T4, | T5, | T5, | + | | V2 | V3 | V2 | V3 | V2 | V3 | + +---+-----+-----+-----+-----+-----+-----+ - * thread block layout ``(4,8):(8,1)`` - * 2 elements per thread """ if not isinstance(thr_layout, Layout): @@ -5154,1552 +3739,6 @@ def make_layout_tv( return (tiler_mn, layout_tv) -def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): - if type(tiler_mn) is tuple: - tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) - - assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance( - tiler_mn.type - ), f"tiler_mn must be a Tile, but got {type(tiler_mn)}" - assert is_static(layout_tv.type) and is_static( - tiler_mn.type - ), "layout tv and tiler mn must be static" - tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( - atom.type, layout_tv.type, tiler_mn.type - ) - - val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) - # Instead of modifying atom which might have been provided by the user, create a brand new - # trait instance and replace the Atom ir.Value with the tiled one - trait = new_from_mlir_values(atom._trait, [val]) - return TiledCopy(atom.op, trait) - - -def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): - """Create a tiled type given a TV partitioner and tiler. - - :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. - :type atom: CopyAtom - :param layout_tv: Thread-value layout - :type layout_tv: Layout - :param tiler_mn: Tile size - :type tiler_mn: Tiler - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) - - -@dsl_user_op -def make_tiled_copy_tv( - atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None -) -> TiledCopy: - """Create a tiled copy given separate thread and value layouts. - - A TV partitioner is inferred based on the input layouts. The input thread layout - must be compact. - - :param atom: Copy atom - :type atom: CopyAtom - :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) - :type thr_layout: Layout - :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs - :type val_layout: Layout - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip) - tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip) - return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) - - -@dsl_user_op -def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_mma: Tiled MMA - :type tiled_mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, - tiled_mma.tv_layout_A_tiled, - (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_mma: Tiled MMA - :type tiled_mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, - tiled_mma.tv_layout_B_tiled, - (tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_mma: Tiled MMA - :type tiled_mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, - tiled_mma.tv_layout_C_tiled, - (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the Src-Layout of tiled_copy. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_copy: Tiled copy - :type tiled_copy: TiledCopy - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): - """Create a tiled copy out of the copy_atom that matches the Dst-Layout of tiled_copy. - - :param atom: Copy atom - :type atom: CopyAtom - :param tiled_copy: Tiled copy - :type tiled_copy: TiledCopy - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for the partitioner - :rtype: TiledCopy - """ - - return _make_tiled_copy( - atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip - ) - - -@dsl_user_op -def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): - """Create the smallest tiled copy that can retile LayoutC_TV for use with pipelined epilogues with subtiled stores. - - :param atom: Copy atom - :type atom: CopyAtom - :param mma: Tiled MMA - :type mma: TiledMma - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - - :return: A tiled copy for partitioner - :rtype: TiledCopy - - :raises ValueError: If the number value of CopyAtom's source layout is greater than the size of TiledMma's LayoutC_TV - """ - # Truncate the V-layout to just the Copy_Atom, keep the V-order - layoutC_tv = mma.tv_layout_C_tiled - val_layout_src = atom.layout_src_tv - num_val_src = size(val_layout_src, mode=[1], loc=loc, ip=ip) - num_val_layoutC_tv = size(layoutC_tv, mode=[1], loc=loc, ip=ip) - if num_val_src > num_val_layoutC_tv: - raise ValueError( - f"The number value of CopyAtom's source layout {num_val_src} " - f"is greater than the size of TiledMma's LayoutC_TV {num_val_layoutC_tv}" - ) - layout_TV = composition( - layoutC_tv, - make_layout( - (size(layoutC_tv, mode=[0], loc=loc, ip=ip), num_val_src), loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - - # Recompute tiler and restride the TV layout for the new tiler - - # Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them - # Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA - mma_tiler = (mma.get_tile_size(0), mma.get_tile_size(1)) - - tiler_0 = filter( - composition( - make_layout(mma_tiler, stride=(1, 0), loc=loc, ip=ip), - layout_TV, - loc=loc, - ip=ip, - ), - loc=loc, - ip=ip, - ) - tiler_1 = filter( - composition( - make_layout(mma_tiler, stride=(0, 1), loc=loc, ip=ip), - layout_TV, - loc=loc, - ip=ip, - ), - loc=loc, - ip=ip, - ) - tiler = (tiler_0, tiler_1) - - tile2mma = composition( - make_layout(mma_tiler, loc=loc, ip=ip), tiler, loc=loc, ip=ip - ) - layout_tv = composition( - left_inverse(tile2mma, loc=loc, ip=ip), layout_TV, loc=loc, ip=ip - ) - - tiler_mn = _pack_tile(tiler, loc=loc, ip=ip) - - return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) - - -#################################################################################################### -# -# cute.gemm and cute.copy -# -#################################################################################################### - - -@dsl_user_op -def gemm( - atom: MmaAtom, - d: Tensor, - a: Tensor, - b: Tensor, - c: Tensor, - *, - loc=None, - ip=None, - **kwargs, -) -> None: - """The GEMM algorithm. - - Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. - warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field. - - All tensors must be partitioned according to the provided MMA Atom. - - For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread - election internally. Manual thread selection is not required in such cases. - - Following dispatch rules are supported: - - - Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1) - - Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N) - - Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N) - - Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N) - - Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) - - :param atom: MMA atom - :type atom: MmaAtom - :param d: Destination tensor - :type d: Tensor - :param a: First source tensor - :type a: Tensor - :param b: Second source tensor - :type b: Tensor - :param c: Third source tensor - :type c: Tensor - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point for MLIR, defaults to None - :type ip: Optional[InsertionPoint], optional - :param kwargs: Additional keyword arguments - :type kwargs: dict - :return: None - :rtype: None - """ - - a_rank = rank(a.shape) - b_rank = rank(b.shape) - c_rank = rank(c.shape) - d_rank = rank(d.shape) - - if a_rank != b_rank: - raise ValueError("`a` and `b` must have the same rank") - - if c_rank != d_rank: - raise ValueError("`c` and `d` must have the same rank") - - if a_rank == 1: - if c_rank > 2: - raise ValueError("`c` must have rank <= 2 when `a` has rank 1") - elif a_rank == 2: - if c_rank not in (2, 3): - raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2") - elif a_rank == 3: - if c_rank != 3: - raise ValueError("`c` must have rank 3 when `a` has rank 3") - - value = atom._unpack(loc=loc, ip=ip, **kwargs) - return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip) - - -@dsl_user_op -def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """Performs a basic element-wise copy. - - This functions **assumes** the following pre-conditions: - 1. `size(src) == size(dst)` - - When the `src` and `dst` shapes are static, the pre-conditions are actually verified and the - element-wise loop is fully unrolled. - - :param src: Source tensor - :type src: Tensor - :param dst: Destination tensor - :type dst: Tensor - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint], optional - """ - - if is_static(src.shape) and is_static(dst.shape): - simt_copy_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - src.element_type.mlir_type, src.element_type.width - ) - simt_copy = _cute_ir.atom(simt_copy_ty, loc=loc, ip=ip) - return _cute_ir.copy(simt_copy, src.value, dst.value, loc=loc, ip=ip) - - s = size(dst, loc=loc, ip=ip) - # Always generate an scf.for Op when one of the tensors is dynamic - for i in for_generate(0, s): - dst[i] = src[i] - yield_out() - - -@dsl_user_op -def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """Performs a basic predicated element-wise copy. - - This functions **assumes** the following pre-conditions: - 1. `size(src) == size(dst)` - 2. `size(src) == size(pred)` - - When all shapes are static, the pre-conditions are actually verified and the element-wise loop - is fully unrolled. - - """ - if src.element_type.width != dst.element_type.width: - raise NotImplementedError( - "basic_copy_if currently only supports equal source and destination " - "element type bit width" - ) - - if is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape): - return _basic_copy_if_static(pred, src, dst, loc=loc, ip=ip) - - s = size(dst, loc=loc, ip=ip) - # Always generate an scf.for Op when one of the tensors is dynamic - for i in for_generate(0, s): - if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) - yield_out() - - -# Version of basic_copy_if when src and dst have static shapes -# - verify size(src) == size(dst) == size(prd) -# - fully unroll the loop for now -def _basic_copy_if_static( - pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None -) -> None: - assert is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape) - if size(src, loc=loc, ip=ip) != size(dst, loc=loc, ip=ip): - raise ValueError( - "basic_copy expects the size of source, destination, and predicate tensors to match" - ) - # Fully unrolled loop in the static case for now - for i in range(size(dst, loc=loc, ip=ip)): - if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) - - -@dsl_user_op -def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: - """ - Auto-vectorizing SIMT copy policy. - - Given a source and destination tensors that are statically shaped, this policy figures out the - largest safe vector width that the copy instruction can take and performs the copy. - """ - if src.element_type.width != dst.element_type.width: - raise NotImplementedError( - "autovec_copy currently only supports equal source and destination " - "element type bit width" - ) - - # We are going to dispatch to copy-with-atom which requires shapes to be static - if not is_static(src.shape) or not is_static(dst.shape): - raise ValueError( - "autovec_copy expects source and destination tensors to be statically shaped" - ) - - vec_layout = max_common_layout(src, dst, loc=loc, ip=ip) - num_common_elements = size(vec_layout, loc=loc, ip=ip) - - # Next we construct an upper-bound on the number bits that can be vectorized by considering - # - the maximum alignment of the layouts - # - the maximum alignment of the pointers - - upper_bound = math.gcd(src.layout.max_alignment, dst.layout.max_alignment) - upper_bound = math.gcd(upper_bound, num_common_elements) - upper_bound *= src.element_type.width - - # For our instructions, the alignment of the pointer is an upper bound to the vector width - # max_alignment, as opposed to alignment, takes into account possible address swizzling - upper_bound = math.gcd(upper_bound, src.iterator.max_alignment * 8) - upper_bound = math.gcd(upper_bound, dst.iterator.max_alignment * 8) - - # Finally, we put a cap at 128b - num_bits_per_copy = math.gcd(upper_bound, 128) - - if (num_common_elements > 1) and (num_bits_per_copy % 8 == 0): - num_common_elements = num_bits_per_copy // src.element_type.width - - # 2 step logical divides ensuring that the divides are valid at every step - vec_src = logical_divide(src, vec_layout, loc=loc, ip=ip) - vec_dst = logical_divide(dst, vec_layout, loc=loc, ip=ip) - tiled_src = logical_divide( - vec_src, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip - ) - tiled_dst = logical_divide( - vec_dst, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip - ) - - # Dispatch to copy with atom - simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - src.element_type.mlir_type, num_bits_per_copy - ) - simt_copy = _cute_ir.atom(simt_type, loc=loc, ip=ip) - return _cute_ir.copy( - simt_copy, tiled_src.value, tiled_dst.value, loc=loc, ip=ip - ) - - # Failed to vectorize, use a basic copy - basic_copy(src, dst, loc=loc, ip=ip) - - -@dsl_user_op -def copy( - atom: CopyAtom, - src: Tensor, - dst: Tensor, - *, - pred: Optional[Tensor] = None, - loc=None, - ip=None, - **kwargs, -) -> None: - """ - The Copy algorithm. - - The "copy with Atom" expects source and destination tensors to be partitioned according to the - provided Copy Atom. Some Atoms require additional Op-specific kw arguments, for example TMA - copies: - - .. code-block:: python - - cute.copy(tma_atom, src, dst, tma_bar_ptr=mbar_ptr, mcast_mask=mask) - - An additional predication tensor can be provided. If the partitioned tensors have the following - logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile - consistent with ``(ATOM_REST,REST_M,...)``. - - For Copy Atoms that require single-threaded execution, the copy op automatically handles thread - election internally. Manual thread selection is not required in such cases. - """ - if isinstance(src.type, _cute_ir.MemRefType) and isinstance( - dst.type, _cute_ir.MemRefType - ): - if src.element_type.width != dst.element_type.width: - raise TypeError( - "`copy` currently only supports equal source and destination " - "element type bit width" - ) - - value = atom._unpack(loc=loc, ip=ip, **kwargs) - if isinstance(pred, Tensor): - pred = pred.value - return _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip) - - -@dsl_user_op -def copy_atom_call( - atom: CopyAtom, - src: Tensor, - dst: Tensor, - *, - pred: Optional[Tensor] = None, - loc=None, - ip=None, - **kwargs, -) -> None: - """ - Execute a single copy atom operation. - - The copy_atom_call operation executes a copy atom with the given operands. - Following src/dst layout of atom are valid: - * ((atom_v)) - * (atom_v) - - Note: The format ((atom_v, rest_v)) is NOT valid for copy_atom_call since it would - require multiple atom operations, which contradicts the definition of a single copy atom call. - - Examples: - - .. code-block:: python - - # Call a copy atom operation - cute.copy_atom_call(copy_atom, src_tensor, dst_tensor) - - An additional predication tensor can be provided. If the partitioned tensors have the following - logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile - consistent with ``(ATOM_REST,REST_M,...)``. - """ - if isinstance(src.type, _cute_ir.MemRefType) and isinstance( - dst.type, _cute_ir.MemRefType - ): - if src.element_type.width != dst.element_type.width: - raise TypeError( - "`copy_atom_call` currently only supports equal source and destination " - "element type bit width" - ) - - value = atom._unpack(loc=loc, ip=ip, **kwargs) - if isinstance(pred, Tensor): - pred = pred.value - return _cute_ir.copy_atom_call( - value, src.value, dst.value, pred=pred, loc=loc, ip=ip - ) - - -def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None: - """ - The Prefetch algorithm. - - The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom. - Prefetch is used for loading tensors from global memory to L2. - - Prefetch accepts Copy Atom but not all are allowed. Currently, only support for tma load tensor prefetch. - - .. code-block:: python - - cute.prefetch(tma_atom, src) - - For Copy Atoms that require single-threaded execution, the copy op automatically handles thread - election internally. Manual thread selection is not required in such cases. - """ - dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip) - value = atom._unpack(loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr) - return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip) - -#################################################################################################### -# -# TensorSSA class (experimental) -# -#################################################################################################### - - -class ReductionOp(Enum): - ADD = auto() - MUL = auto() - MAX = auto() - MIN = auto() - INC = auto() - DEC = auto() - AND = auto() - OR = auto() - XOR = auto() - - def __str__(self): - return self.name.lower() - - -class TensorSSA(cutlass_arith.ArithValue): - """A class representing thread local data from CuTe Tensor in value semantic and immutable. - - :param value: Flatten vector as ir.Value holding logic data of SSA Tensor - :type value: ir.Value - :param shape: The nested shape in CuTe of the vector - :type shape: Shape - :param dtype: Data type of the tensor elements - :type dtype: Type[Numeric] - - :ivar _shape: The nested shape in CuTe of the vector - :ivar _dtype: Data type of the tensor elements - - :raises ValueError: If shape is not static - """ - - def __init__(self, value, shape: Shape, dtype: Type[Numeric]): - """Initialize a new TensorSSA object. - - :param value: Flatten vector as ir.Value holding logic data of SSA Tensor - :type value: ir.Value - :param shape: The nested shape in CuTe of the vector - :type shape: Shape - :param dtype: Data type of the tensor elements - :type dtype: Type[Numeric] - :raises ValueError: If shape is not static - """ - if not is_static(shape): - raise ValueError("dynamic shape is not supported") - - signed = dtype.signed if issubclass(dtype, Integer) else False - super().__init__(value, signed) - - self._shape = shape - self._dtype = dtype - self._layout = None - - @property - def dtype(self) -> Type[Numeric]: - return self._dtype - - @property - def element_type(self) -> Type[Numeric]: - return self._dtype - - @abstractmethod - def __extract_mlir_values__(self): - return [self] - - @abstractmethod - def __new_from_mlir_values__(self, values): - return TensorSSA(values[0], self.shape, self.dtype) - - def __str__(self): - return f"tensor_value<{self.type} o {self.shape}>" - - @property - def shape(self): - return self._shape - - @overload - def _apply_op(self, op, other: "TensorSSA", flip, *, loc, ip) -> "TensorSSA": ... - - @overload - def _apply_op( - self, op, other: cutlass_arith.ArithValue, flip, *, loc, ip - ) -> "TensorSSA": ... - - @overload - def _apply_op( - self, op, other: Union[int, float, bool], flip, *, loc, ip - ) -> "TensorSSA": ... - - def _apply_op(self, op, other, flip=False, *, loc=None, ip=None): - def get_attr_for_type(ty, value): - if isinstance(ty, ir.IntegerType): - return ir.IntegerAttr.get(ty, value) - elif isinstance(ty, ir.FloatType): - return ir.FloatAttr.get(ty, value) - else: - raise TypeError(f"unsupported type: {ty}") - - # Canonicalize into Numeric - if isinstance(other, (int, float, bool)) or ( - not isinstance(other, TensorSSA) - and isinstance(other, cutlass_arith.ArithValue) - ): - other = as_numeric(other) - - # Promote types - lhs, rhs, res_type = _binary_op_type_promote(self, other) - - # Promote scalar to vector - if not isinstance(rhs, TensorSSA): - if isinstance(rhs, Numeric): - vect_val = vector.broadcast(lhs.type, rhs.ir_value(loc=loc, ip=ip)) - else: - elem_attr = get_attr_for_type(lhs.type.element_type, rhs) - vect_attr = ir.DenseElementsAttr.get_splat(lhs.type, elem_attr) - vect_val = arith.constant(lhs.type, vect_attr, loc=loc, ip=ip) - rhs = TensorSSA(vect_val, lhs.shape, lhs.dtype) - - if flip: - lhs, rhs = rhs, lhs - - if op in ( - operator.lt, - operator.le, - operator.gt, - operator.ge, - operator.eq, - operator.ne, - ): - res_type = Boolean - - assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}" - - def _broadcast(s, t): - if s == 1: - return t - elif t == 1: - return s - elif s == t: - return s - else: - raise ValueError(f"cannot broadcast {s} and {t}") - - max_rank = max(rank(lhs.shape), rank(rhs.shape)) - lhs_shape = append(lhs.shape, 1, up_to_rank=max_rank) - rhs_shape = append(rhs.shape, 1, up_to_rank=max_rank) - res_shape = transform_leaf(_broadcast, lhs_shape, rhs_shape) - - # broadcast to the same shape - lhs = lhs.broadcast_to(res_shape) - rhs = rhs.broadcast_to(res_shape) - - if ( - op in (operator.add, operator.sub) - and lhs.dtype == Boolean - and rhs.dtype == Boolean - ): - res = op(lhs.to(Int32), rhs.to(Int32)) - zero = zeros_like(res) - res = res.__ne__(zero).to(res_type) - else: - lhs_val = lhs.maybe_downcast() - rhs_val = rhs.maybe_downcast() - - if issubclass(lhs.dtype, Integer): - lhs_val = lhs_val.with_signedness(lhs.dtype.signed) - - if issubclass(rhs.dtype, Integer): - rhs_val = rhs_val.with_signedness(rhs.dtype.signed) - - res_vect = op(lhs_val, rhs_val) - res = TensorSSA(res_vect, lhs._shape, res_type) - - return res - - def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA": - """ - Broadcast the tensor to the target shape. - """ - # pad source shape to the same rank - shape = append(self.shape, 1, up_to_rank=rank(target_shape)) - if shape == target_shape: - return self - - def _check_broadcast(s, t): - if s != t and s != 1: - raise ValueError( - f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}" - ) - - transform_leaf(_check_broadcast, shape, target_shape) - - # reshape to flatten N-D vector - flat_shp = flatten_to_tuple(shape) - temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type) - temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) - - # broadcast to result N-D vector - flat_tgt_shp = flatten_to_tuple(target_shape) - temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type) - temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip) - - res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore - res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip) - - return TensorSSA(res_1d_vect, target_shape, self.dtype) - - def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the results of tensor^other. - - :param other: The other tensor for exponent. - :type other: TensorSSA - :return: The power of the tensor. - :rtype: TensorSSA - """ - return self._apply_op(operator.pow, other, loc=loc, ip=ip) - - def __rpow__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the results of other^tensor. - - :param other: The other tensor to compute power with. - :type other: TensorSSA - :return: The element-wise power of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.pow, other, flip=True, loc=loc, ip=ip) - - def __add__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the sum of the tensor and another tensor. - - :param other: The other tensor to add. - :type other: TensorSSA - :return: The sum of the two tensors with the same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.add, other, loc=loc, ip=ip) - - def __radd__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the sum of the tensor and another tensor (reverse add) - - :param other: The other tensor to add. - :type other: TensorSSA - :return: The sum of the two tensors with the same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.add, other, flip=True, loc=loc, ip=ip) - - def __sub__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the difference of the tensor and another tensor. - - :param other: The other tensor to subtract. - :type other: TensorSSA - :return: The subtraction of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.sub, other, loc=loc, ip=ip) - - def __rsub__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the difference of the tensor and another tensor (reverse subtract) - - :param other: The other tensor to subtract. - :type other: TensorSSA - :return: The subtraction of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.sub, other, flip=True, loc=loc, ip=ip) - - def __mul__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the multiplication of the tensor and another tensor. - - :param other: The other tensor to multiply. - :type other: TensorSSA - :return: The multiplication of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mul, other, loc=loc, ip=ip) - - def __rmul__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the multiplication of the tensor and another tensor (reverse multiply) - - :param other: The other tensor to multiply. - :type other: TensorSSA - :return: The multiplication of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mul, other, flip=True, loc=loc, ip=ip) - - def __mod__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the modulo of the tensor and another tensor. - - :param other: The other tensor to compute modulo with. - :type other: TensorSSA - :return: The element-wise modulo of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mod, other, loc=loc, ip=ip) - - def __rmod__(self, other) -> "TensorSSA": - """ - Returns the modulo of the tensor and another tensor (reverse modulo) - - :param other: The other tensor to compute modulo with. - :type other: TensorSSA - :return: The element-wise modulo of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.mod, other, flip=True) - - def __floordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the floordiv(//) of the tensor and another tensor. - - :param other: The other tensor to compute floordiv with. - :type other: TensorSSA - :return: The floordiv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.floordiv, other, loc=loc, ip=ip) - - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the floordiv(//) of the tensor and another tensor (reverse floordiv) - - :param other: The other tensor to compute floordiv with. - :type other: TensorSSA - :return: The floordiv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.floordiv, other, flip=True, loc=loc, ip=ip) - - def __truediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the truediv(/) of the tensor and another tensor. - - :param other: The other tensor to compute truediv with. - :type other: TensorSSA - :return: The truediv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.truediv, other, loc=loc, ip=ip) - - def __rtruediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the truediv(/) of the tensor and another tensor (reverse truediv) - - :param other: The other tensor to compute truediv with. - :type other: TensorSSA - :return: The truediv of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.truediv, other, flip=True, loc=loc, ip=ip) - - def __eq__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the comparison of the tensor and another tensor as mask - - :param other: The other tensor to compare. - :type other: TensorSSA - :return: The comparison of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.eq, other, loc=loc, ip=ip) - - def __ne__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise not equal comparison of the tensor and another tensor. - - :param other: The other tensor to compare. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self != other. - :rtype: TensorSSA - """ - return self._apply_op(operator.ne, other, loc=loc, ip=ip) - - def __lt__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise less than comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self < other. - :rtype: TensorSSA - """ - return self._apply_op(operator.lt, other, loc=loc, ip=ip) - - def __le__(self, other) -> "TensorSSA": - """ - Returns the element-wise less than or equal comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self <= other. - :rtype: TensorSSA - """ - return self._apply_op(operator.le, other) - - def __gt__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise greater than comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self > other. - :rtype: TensorSSA - """ - return self._apply_op(operator.gt, other) - - def __ge__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise greater than or equal comparison of the tensor and another tensor. - - :param other: The other tensor to compare with. - :type other: TensorSSA - :return: A boolean tensor with same shape as inputs, True where self >= other. - :rtype: TensorSSA - """ - return self._apply_op(operator.ge, other, loc=loc, ip=ip) - - def __xor__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise XOR of the tensor and another tensor. - - :param other: The other tensor to perform XOR with. - :type other: TensorSSA - :return: The element-wise XOR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.xor, other) - - def __rxor__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the bitwise XOR of the tensor and another tensor. - - :param other: The other tensor to compute XOR with. - :type other: TensorSSA - :return: The element-wise bitwise XOR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.xor, other, flip=True, loc=loc, ip=ip) - - def __or__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise OR of the tensor and another tensor. - - :param other: The other tensor to perform OR with. - :type other: TensorSSA - :return: The element-wise OR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.or_, other) - - def __ror__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise OR of the tensor and another tensor. - - :param other: The other tensor to perform OR with. - :type other: TensorSSA - :return: The element-wise OR of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.or_, other, flip=True) - - def __and__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise AND of the tensor and another tensor. - - :param other: The other tensor to perform AND with. - :type other: TensorSSA - :return: The element-wise AND of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.and_, other) - - def __rand__(self, other, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the element-wise AND of the tensor and another tensor. - - :param other: The other tensor to perform AND with. - :type other: TensorSSA - :return: The element-wise AND of two tensors with same shape as inputs. - :rtype: TensorSSA - """ - return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip) - - def __neg__(self, *, loc=None, ip=None) -> "TensorSSA": - """ - Returns the negation of the tensor. - - :return: The element-wise negation of the tensor - :rtype: TensorSSA - """ - - return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip) - - def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None): - # Coalesce and flatten source layout at terminal of coordinate - # (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...) - crd_shp = product_like(self._shape, target_profile=crd, loc=loc, ip=ip) - - # Flatten coordinate - flat_shp = flatten(crd_shp) - assert isinstance(flat_shp, tuple) and is_static(flat_shp) - # (C_0,(C_1,...), ...) -> (C_0,C_1,C_2,...) - flat_crd = flatten(crd) - - assert isinstance(flat_crd, tuple) and is_static(flat_crd) - return flat_shp, flat_crd - - def _build_result(self, res_vect, res_shp, *, loc=None, ip=None): - if isinstance(res_shp, ir.Value): - raise ValueError( - f"expects static shape and coordinates, but got {self._shape} and {crd}" - ) - - # cast back to 1D vector - res_1d_ty = ir.VectorType.get([size(res_shp)], self.type.element_type) - res_1d_vect = vector.shape_cast(res_1d_ty, res_vect, loc=loc, ip=ip) - return TensorSSA(res_1d_vect, res_shp, self.dtype) - - @dsl_user_op - def __getitem__( - self, crd: Coord, *, loc=None, ip=None - ) -> Union["TensorSSA", Numeric]: - """Access or slice tensor elements using coordinates. - - This method implements tensor evaluation T(c) = *(E + L(c)) where E is the iterator/engine - and L is the layout. It supports both direct element access and slicing operations. - - :param crd: Coordinate or slice specification for accessing tensor elements - :type crd: Coord - :param loc: Source location for MLIR operation tracking, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for MLIR operation, defaults to None - :type ip: Optional[InsertionPoint] - :return: Tensor element value or sliced subtensor - :rtype: Union[TensorSSA, Numeric] - - :raises ValueError: If coordinate access is invalid for the tensor layout - - **Examples:** - - .. code-block:: python - - # Create a fragment from rmem as shape (8, 4) - layout = make_layout((8, 4)) - tensor = make_fragment(layout, Float32) - frg = tensor.load() - - # Direct element access - val = frg[0] # Returns first element of fragment - val = frg[(0, 1)] # Returns element at (0, 1) - - # Slice access - sliced = frg[(3, None)] # Returns fragment slice - """ - # short-cut to no-op - if crd is None: - return self - - if not has_underscore(crd): - if self._layout is None: - self._layout = make_layout(self._shape, loc=loc, ip=ip) - idx = crd2idx(crd, self._layout, loc=loc, ip=ip) - idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip) - res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip) - return self.dtype(res_val) - - if not is_static(crd): - raise ValueError("dynamic coordinate is not supported") - - flat_shp, flat_crd = self._flatten_shape_and_coord(crd) - - multi_dim_ty = ir.VectorType.get(list(flat_shp), self.type.element_type) - # vector -> vector - tmp_vect = vector.shape_cast(multi_dim_ty, self) - - # Slice and keep dims matching `_` or None - res_shp = slice_(self._shape, crd) - if isinstance(res_shp, ir.Value): - raise TypeError( - f"expects static shape and coordinates, but got {self._shape} and {crd}" - ) - - # Offsets is index of coordinates if NOT `_` otherwise 0 - offsets = [c if c is not None else 0 for c in flat_crd] - # Sizes is size of shapes if `_` otherwise 1 - sizes = [s if c is None else 1 for s, c in zip(flat_shp, flat_crd)] - # Logic stride to index vector. Only support stride-1 by vector - strides = [1] * rank(flat_shp) - - # Vector slice on N-D vector - res_ty = ir.VectorType.get(list(sizes), self.type.element_type) - res_vect = vector.extract_strided_slice( - res_ty, tmp_vect, offsets=offsets, sizes=sizes, strides=strides - ) - - # Slice and keep dims matching `_` or None - res_shp = slice_(self._shape, crd) - return self._build_result(res_vect, res_shp, loc=loc, ip=ip) - - @dsl_user_op - def to(self, dtype: Type[Numeric], *, loc=None, ip=None): - """Convert the tensor to a different numeric type. - - :param dtype: The target numeric type to cast to. - :type dtype: Type[Numeric] - :return: A new tensor with the same shape but with elements cast to the target type. - :rtype: TensorSSA - :raises TypeError: If dtype is not a subclass of Numeric. - :raises NotImplementedError: If dtype is an unsigned integer type. - """ - if dtype is ir.Value: - return self - - if not isclass(dtype) or not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a type of Numeric, but got {type(dtype)}") - - src_dtype = self.dtype - if src_dtype == dtype: - return self - - # maybe downcast can lose signedness - src = self.maybe_downcast().with_signedness(self.signed) - if src_dtype.is_float and dtype.is_float: - res_vect = cutlass_arith.cvtf(src, dtype.mlir_type, loc=loc, ip=ip) - elif src_dtype.is_float and issubclass(dtype, Integer): - res_vect = cutlass_arith.fptoi( - src, dtype.signed, dtype.mlir_type, loc=loc, ip=ip - ) - elif issubclass(src_dtype, Integer) and dtype.is_float: - res_vect = cutlass_arith.itofp( - src, src_dtype.signed, dtype.mlir_type, loc=loc, ip=ip - ) - else: - res_vect = cutlass_arith.int_to_int(src, dtype, loc=loc, ip=ip) - - return TensorSSA(res_vect, self._shape, dtype) - - def ir_value(self, *, loc=None, ip=None): - return self - - def ir_value_int8(self, *, loc=None, ip=None): - """ - Returns int8 ir value of Boolean tensor. - When we need to store Boolean tensor ssa, use ir_value_int8(). - - :param loc: Source location information, defaults to None - :type loc: Optional[Location], optional - :param ip: Insertion point for MLIR operations, defaults to None - :type ip: Optional[InsertionPoint], optional - :return: The int8 value of this Boolean - :rtype: ir.Value - """ - assert ( - self.element_type is Boolean - ), f"Only boolean type needs to be converted to int8, got {self.element_type}" - - if not hasattr(self, "_value_int8"): - self._value_int8 = arith.extsi( - T.vector(self.type.shape[0], T.i8()), self, loc=loc, ip=ip - ) - return self._value_int8 - - def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None): - """ - Perform reduce on selected modes with given predefined reduction op. - - :param op: The reduction operator to use (operator.add or operator.mul) - :type op: operator - :param init_val: The initial value for the reduction - :type init_val: numeric - :param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept. - :type reduction_profile: Coord - - :return: The reduced tensor - :rtype: TensorSSA - - **Examples:** - - .. code-block:: python - - reduce(f32 o (4,)) - => f32 - - reduce(f32 o (4, 5)) - => f32 - reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1)) - => f32 o (4,) - reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1))) - => f32 o (4, (5,)) - """ - # short-cut to no-op - if reduction_profile is None: - return self - - if not is_weakly_congruent(reduction_profile, self.shape): - raise ValueError( - f"Expect reduction_profile be weakly congruent to the shape of the tensor, " - f"but got {reduction_profile} and {self.shape}" - ) - - if op is ReductionOp.ADD: - red_kind = vector.CombiningKind.ADD - elif op is ReductionOp.MUL: - red_kind = vector.CombiningKind.MUL - elif op is ReductionOp.MAX: - red_kind = vector.CombiningKind.MAXIMUMF - elif op is ReductionOp.MIN: - red_kind = vector.CombiningKind.MINIMUMF - else: - raise NotImplementedError( - f"{op} is not supported, expects one of " - f"{ReductionOp.ADD, ReductionOp.MUL, ReductionOp.MAX, ReductionOp.MIN}" - ) - - elem_ty = self.element_type - # Canonicalize to `Numeric` and convert into MLIR value - init_val = as_numeric(init_val).ir_value(loc=loc, ip=ip) - - if depth(reduction_profile) == 0: - return vector.reduction( - elem_ty.mlir_type, red_kind, self, acc=init_val, loc=loc, ip=ip - ) - - flat_shp, flat_prof = self._flatten_shape_and_coord( - reduction_profile, loc=loc, ip=ip - ) - assert depth(flat_shp) == 1 and depth(flat_prof) == 1 - assert rank(flat_shp) == rank(flat_prof) - - temp_ty = ir.VectorType.get(list(flat_shp), elem_ty.mlir_type) - temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) - - if isinstance(flat_prof, tuple): - red_dims = [i for i, x in enumerate(flat_prof) if x is not None] - else: - red_dims = [0] - - temp_acc_shp = slice_(flat_shp, flat_prof, loc=loc, ip=ip) - temp_acc_ty = ir.VectorType.get(list(temp_acc_shp), elem_ty.mlir_type) - - init_val = vector.broadcast(temp_acc_ty, init_val, loc=loc, ip=ip) - res_vect = vector.multi_reduction( - red_kind, temp_vect, acc=init_val, reduction_dims=red_dims, loc=loc, ip=ip - ) - - # Slice and keep dims matching `_` or None - res_shp = slice_(self.shape, reduction_profile, loc=loc, ip=ip) - return self._build_result(res_vect, res_shp, loc=loc, ip=ip) - - -@dsl_user_op -def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA: - """ - Return a new TensorSSA of given shape and type, filled with fill_value. - - :param shape: Shape of the new tensor. - :type shape: tuple - :param fill_value: Value to fill the tensor with. - :type fill_value: scalar - :param dtype: Data type of the tensor. - :type dtype: Type[Numeric] - :return: Tensor of fill_value with the specified shape and dtype. - :rtype: TensorSSA - """ - size = product(shape, loc=loc, ip=ip) - if not is_static(size): - raise ValueError("shape must be static") - - if isinstance(fill_value, (ir.Value, int, float, bool)): - fill_value = dtype(fill_value) - elif isinstance(fill_value, Numeric): - fill_value = fill_value.to(dtype, loc=loc, ip=ip) - else: - raise ValueError(f"Expected fill_value be numeric type, but got {fill_value}") - - res_ty = T.vector(size, dtype.mlir_type) - res_val = vector.splat(res_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) - return TensorSSA(res_val, shape, dtype) - - -def full_like( - a: Union[TensorSSA, Tensor], - fill_value, - dtype: Union[None, Type[Numeric]] = None, - *, - loc=None, - ip=None, -) -> TensorSSA: - """ - Return a full TensorSSA with the same shape and type as a given array. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: array_like - :param fill_value: Fill value. - :type fill_value: array_like - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Union[None, Type[Numeric]], optional - :return: Tensor of `fill_value` with the same shape and type as `a`. - :rtype: TensorSSA - - .. seealso:: - :func:`empty_like`: Return an empty array with shape and type of input. - :func:`ones_like`: Return an array of ones with shape and type of input. - :func:`zeros_like`: Return an array of zeros with shape and type of input. - :func:`full`: Return a new array of given shape filled with value. - - **Examples:** - - .. code-block:: python - - frg = cute.make_fragment(Float32, (2, 3)) - a = frg.load() - b = cute.full_like(a, 1.0) - """ - if not hasattr(a, "shape"): - raise TypeError(f"Expect `a` be shaped type, but got {type(a)}") - - return full( - a.shape, fill_value, dtype if dtype is not None else a.dtype, loc=loc, ip=ip - ) - - -def empty_like(a, dtype=None): - """ - Return a new TensorSSA with the same shape and type as a given array, without initializing entries. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: TensorSSA - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Type[Numeric], optional - :return: Uninitialized tensor with the same shape and type (unless overridden) as `a`. - :rtype: TensorSSA - """ - return full_like(a, 0, dtype) - - -def ones_like(a, dtype=None): - """ - Return a TensorSSA of ones with the same shape and type as a given array. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: TensorSSA - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Type[Numeric], optional - :return: Tensor of ones with the same shape and type (unless overridden) as `a`. - :rtype: TensorSSA - """ - return full_like(a, 1, dtype) - - -def zeros_like(a, dtype=None, *, loc=None, ip=None): - """ - Return a TensorSSA of zeros with the same shape and type as a given array. - - :param a: The shape and data-type of `a` define these same attributes of the returned array. - :type a: TensorSSA - :param dtype: Overrides the data type of the result, defaults to None - :type dtype: Type[Numeric], optional - :return: Tensor of zeros with the same shape and type (unless overridden) as `a`. - :rtype: TensorSSA - """ - return full_like(a, 0, dtype, loc=loc, ip=ip) - - -def where( - cond: TensorSSA, x: TensorSSA, y: TensorSSA, *, loc=None, ip=None -) -> TensorSSA: - """ - Return elements chosen from x or y depending on condition. - - :param cond: Where True, yield x, where False, yield y. - :type cond: TensorSSA - :param x: Values from which to choose when condition is True. - :type x: TensorSSA - :param y: Values from which to choose when condition is False. - :type y: TensorSSA - :return: A tensor with elements from x where condition is True, and elements from y where condition is False. - :rtype: TensorSSA - """ - if x.dtype != y.dtype: - raise ValueError( - f"x and y must have the same dtype, but got {x.dtype} and {y.dtype}" - ) - - if cond.dtype != Boolean: - raise ValueError(f"cond must be Boolean type, but got {cond.dtype}") - - return TensorSSA( - arith.select(cond.ir_value(), x, y, loc=loc, ip=ip), x.shape, x.dtype - ) - - -def any_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: - """ - Test whether any tensor element evaluates to True. - - :param x: Input tensor. - :type x: TensorSSA - :return: Returns a TensorSSA scalar containing True if any element of x is True, False otherwise. - :rtype: TensorSSA - """ - is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) - return Boolean( - vector.reduction(T.bool(), vector.CombiningKind.OR, is_true, loc=loc, ip=ip) - ) - - -def all_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: - """ - Test whether all tensor elements evaluate to True. - - :param x: Input tensor. - :type x: TensorSSA - :return: Returns a TensorSSA scalar containing True if all elements of x are True, False otherwise. - :rtype: TensorSSA - """ - is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) - return Boolean( - vector.reduction(T.bool(), vector.CombiningKind.AND, is_true, loc=loc, ip=ip) - ) - - ############################################################################## # User defined struct ############################################################################## @@ -6727,7 +3766,7 @@ class struct: intA : cutlass.Int16 - # Supports aligment for its elements: + # Supports alignment for its elements: @cute.struct class StorageB: a: cute.struct.Align[ @@ -6854,10 +3893,12 @@ class struct: :raises TypeError: If the layout is incompatible with the swizzle. :raises AssertionError: If the size of the memory range is not greater than zero. """ + from .tensor import make_tensor + assert self._size > 0 # make tensor if isinstance(layout, ComposedLayout) and (swizzle is not None): - raise TypeError(f"incompatible layout with swizzle") + raise TypeError("incompatible layout with swizzle") elem_type = self._dtype if dtype is None else dtype ptr = recast_ptr(self._base, swizzle, dtype=elem_type) res = make_tensor(ptr, layout) @@ -6952,7 +3993,7 @@ class struct: self._cls = cls self.__name__ = f"struct::{cls.__name__}" # Get the class annotations - self._annotations = cls.__annotations__ + self._annotations = getattr(cls, "__annotations__", {}) # Create a dictionary to store the offsets self._offsets: Dict[str, int] = {} @@ -7000,7 +4041,7 @@ class struct: f"Struct element only support struct/array/base_dsl scalar, " f"but got {object}" ) - # Total aligment determined by the strictest requirement + # Total alignment determined by the strictest requirement alignment = max(alignment, sub_align) # Total size determined by alignment self._align_of = alignment @@ -7068,3 +4109,19 @@ class struct: align & (align - 1) ), "align should be a strictly positive power of 2." return (offset + (align - 1)) & ~(align - 1) + + +# Deprecated usage but keep them to avoid breaking some examples uses `cute.core.ThrMma` + +from .atom import ThrCopy as _ThrCopy +from .atom import ThrMma as _ThrMma + + +@deprecated("cute.core.ThrMma is deprecated, use cute.ThrMma instead") +class ThrMma(_ThrMma): + pass + + +@deprecated("cute.core.ThrCopy is deprecated, use cute.ThrCopy instead") +class ThrCopy(_ThrCopy): + pass diff --git a/python/CuTeDSL/cutlass/cute/math.py b/python/CuTeDSL/cutlass/cute/math.py index daaa6082..952f5ae9 100644 --- a/python/CuTeDSL/cutlass/cute/math.py +++ b/python/CuTeDSL/cutlass/cute/math.py @@ -9,8 +9,11 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from .core import TensorSSA +from typing import Callable, Union + from .typing import Numeric +from .tensor import TensorSSA + from cutlass._mlir.dialects import math, arith from typing import Callable, Union @@ -62,7 +65,7 @@ def acos( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = acos(y) # Compute arc cosine """ @@ -85,7 +88,7 @@ def asin( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = asin(y) # Compute arc sine """ @@ -108,11 +111,10 @@ def atan( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = atan(y) # Compute arc tangent """ - raise NotImplementedError("atan is not implemented") return _math_op(math.atan, fastmath, a) @@ -137,8 +139,8 @@ def atan2( .. code-block:: - y = cute.make_fragment(ptr1, layout).load() # y coordinates - x = cute.make_fragment(ptr2, layout).load() # x coordinates + y = cute.make_rmem_tensor(ptr1, layout).load() # y coordinates + x = cute.make_rmem_tensor(ptr2, layout).load() # x coordinates theta = atan2(y, x) # Compute angles """ return _math_op(math.atan2, fastmath, a, b) @@ -160,7 +162,7 @@ def cos( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = cos(y) # Compute cosine """ @@ -186,7 +188,7 @@ def erf( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = erf(y) # Compute error function """ @@ -209,7 +211,7 @@ def exp( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = exp(y) # Compute exponential """ @@ -232,7 +234,7 @@ def exp2( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = exp2(y) # Compute 2^x """ @@ -255,7 +257,7 @@ def log( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = log(y) # Compute natural logarithm """ @@ -278,7 +280,7 @@ def log2( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = log2(y) # Compute log base 2 """ @@ -301,7 +303,7 @@ def log10( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = log10(y) # Compute log base 10 """ @@ -326,7 +328,7 @@ def rsqrt( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = rsqrt(y) # Compute 1/√x """ @@ -349,7 +351,7 @@ def sin( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = sin(y) # Compute sine """ @@ -372,7 +374,7 @@ def sqrt( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = sqrt(y) # Compute square root """ @@ -395,7 +397,7 @@ def tan( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = tan(y) # Compute tangent """ @@ -418,7 +420,7 @@ def tanh( .. code-block:: - x = cute.make_fragment(layout) # Create tensor + x = cute.make_rmem_tensor(layout) # Create tensor y = x.load() # Load values z = tanh(y) # Compute hyperbolic tangent """ diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/python/CuTeDSL/cutlass/cute/nvgpu/common.py index 1b0c4c82..ad993fcd 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/common.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/common.py @@ -18,8 +18,9 @@ import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir -from .. import core +from .. import atom from ..typing import Float16, Float32, Float64, Numeric +from cutlass import cute class OpError(DSLBaseError): @@ -28,7 +29,7 @@ class OpError(DSLBaseError): """ def __init__( - self, op: core.Op, message: str, suggestion: Optional[str] = None + self, op: atom.Op, message: str, suggestion: Optional[str] = None ) -> None: if suggestion is None: # Default suggestion @@ -48,7 +49,7 @@ class OpError(DSLBaseError): @dataclass(frozen=True) -class MmaUniversalOp(core.MmaOp): +class MmaUniversalOp(atom.MmaOp): """ The universal MMA Operation. @@ -65,7 +66,7 @@ class MmaUniversalOp(core.MmaOp): if self.abacc_dtype not in [Float16, Float32, Float64]: raise OpError( self, - f"expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64", + "expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64", ) def __str__(self) -> str: @@ -75,14 +76,14 @@ class MmaUniversalOp(core.MmaOp): ) def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait": - shape_mnk_attr = ir.Attribute.parse(f'#cute.shape<"(1,1,1)">') + shape_mnk_attr = ir.Attribute.parse('#cute.shape<"(1,1,1)">') atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get( shape_mnk_attr, self.abacc_dtype.mlir_type, self.abacc_dtype.mlir_type, self.abacc_dtype.mlir_type, ) - return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip)) + return MmaUniversalTrait(cute.make_atom(atom_ty, loc=loc, ip=ip)) def _verify_fragment_A(self, input, *, loc=None, ip=None): pass @@ -90,7 +91,8 @@ class MmaUniversalOp(core.MmaOp): def _verify_fragment_B(self, input, *, loc=None, ip=None): pass -class MmaUniversalTrait(core.Trait): + +class MmaUniversalTrait(atom.Trait): pass @@ -137,8 +139,9 @@ class MemoryScope(enum.Enum): def _to_ir(self) -> _cute_ir.MemScopeKind: return self.value + @dataclass(frozen=True) -class CopyUniversalOp(core.CopyOp): +class CopyUniversalOp(atom.CopyOp): """ The universal Copy Operation. @@ -182,8 +185,8 @@ class CopyUniversalOp(core.CopyOp): memory_order._to_ir(), memory_scope._to_ir(), ) - return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return CopyUniversalTrait(cute.make_atom(ty, loc=loc, ip=ip)) -class CopyUniversalTrait(core.Trait): +class CopyUniversalTrait(atom.Trait): pass diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py index 246360c2..5fc08b65 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py @@ -36,4 +36,5 @@ __all__ = [ "fence_tma_desc_acquire", "cp_fence_tma_desc_release", "fence_tma_desc_release", + "group_bulk_copy_modes", ] diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py index a1549560..55f8bb86 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py @@ -13,13 +13,15 @@ import enum from dataclasses import dataclass from typing import Optional, Type -from cutlass.cutlass_dsl import CuTeDSL, t +from cutlass import cute +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import CuTeDSL -import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir -from ...core import CopyOp, Trait, ReductionOp +from ...atom import CopyOp, Trait +from ...tensor import ReductionOp from ...typing import Int16, Pointer, Integer, Numeric from ..common import OpError from ..tcgen05.mma import CtaGroup @@ -73,19 +75,13 @@ class CopyG2SOp(CopyOp): def _make_trait( self, - copy_internal_type: Type[t.Numeric], + copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs, ) -> "CopyG2STrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", None) - # Verify that the user provided enum values - if not isinstance(self.cache_mode, LoadCacheMode): - raise OpError( - self, - "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", - ) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0): raise ValueError( "expects a 'num_bits_per_copy' kw argument of type int that is positive " @@ -100,7 +96,7 @@ class CopyG2SOp(CopyOp): ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get( copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy ) - return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return CopyG2STrait(cute.make_atom(ty, loc=loc, ip=ip)) class CopyG2STrait(Trait): @@ -114,8 +110,18 @@ class CopyG2STrait(Trait): #################################################################################################### TMA_MBAR_PTR_FIELD_NAME = "tma_bar" -TMA_MASK_FIELD_NAME = "mcast_mask" +TMA_MCAST_MASK_FIELD_NAME = "mcast_mask" TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr" +TMA_BYTE_MASK_FIELD_NAME = "byte_mask" + + +class TmaCopyOp(CopyOp): + """ + Base class for all TMA copy operations. + """ + + pass + # # TMA GMEM -> SMEM copies @@ -123,7 +129,7 @@ TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr" @dataclass(frozen=True) -class CopyBulkTensorTileG2SOp(CopyOp): +class CopyBulkTensorTileG2SOp(TmaCopyOp): """ Bulk tensor asynchrnous GMEM to SMEM Copy Operation using the TMA unit. @@ -133,27 +139,20 @@ class CopyBulkTensorTileG2SOp(CopyOp): cta_group: CtaGroup = CtaGroup.ONE - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - def __post_init__(self) -> None: if not isinstance(self.cta_group, CtaGroup): raise OpError( self, "expects the 'cta_group' parameter to be a CtaGroup instance" ) # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: + arch: Arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: raise OpError( self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) - if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": + if (self.cta_group == CtaGroup.TWO) and arch.major == Arch.sm_90.major: raise OpError( self, f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", @@ -163,7 +162,7 @@ class CopyBulkTensorTileG2SOp(CopyOp): def __str__(self) -> str: res = "cp.async GMEM -> SMEM bulk tensor copy Operation" if self.cta_group == CtaGroup.TWO: - res += f"\n CTA group = 2" + res += "\n CTA group = 2" return res def _make_trait( @@ -225,7 +224,7 @@ class CopyBulkTensorTileG2SNonExecTrait(Trait): @dataclass(frozen=True) -class CopyBulkTensorTileG2SMulticastOp(CopyOp): +class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp): """ Bulk tensor asynchrnous multicast GMEM to SMEM Copy Operation using the TMA unit. @@ -235,27 +234,20 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp): cta_group: CtaGroup = CtaGroup.ONE - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - def __post_init__(self): if not isinstance(self.cta_group, CtaGroup): raise OpError( self, "expects the 'cta_group' parameter to be a CtaGroup instance" ) # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: + arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: raise OpError( self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) - if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": + if (self.cta_group == CtaGroup.TWO) and arch.major == Arch.sm_90.major: raise OpError( self, f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", @@ -265,7 +257,7 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp): def __str__(self) -> str: res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation" if self.cta_group == CtaGroup.TWO: - res += f"\n CTA group = 2" + res += "\n CTA group = 2" return res def _make_trait( @@ -309,12 +301,7 @@ class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): "expects a multicast mask to be provided via the mcast_mask kw argument" ) exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) - attr_str = f"#cute_nvgpu.atom_copy_field_tmaload" - attr = ir.Attribute.parse(attr_str) - exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip - ) - attr_str = f"#cute_nvgpu.atom_copy_field_tmaload" + attr_str = "#cute_nvgpu.atom_copy_field_tmaload" attr = ir.Attribute.parse(attr_str) exec_value = _cute_nvgpu_ir.atom_set_value( exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip @@ -325,6 +312,13 @@ class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): exec_value = _cute_nvgpu_ir.atom_set_value( exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip ) + # Set the tma_bar_ptr at last to ensure that the atom creation and setting + # operations above can be moved outside the loop + attr_str = "#cute_nvgpu.atom_copy_field_tmaload" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip + ) return exec_value @@ -334,7 +328,7 @@ class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): @dataclass(frozen=True) -class CopyBulkTensorTileS2GOp(CopyOp): +class CopyBulkTensorTileS2GOp(TmaCopyOp): """ Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit. @@ -342,20 +336,13 @@ class CopyBulkTensorTileS2GOp(CopyOp): This Operation uses TMA in the ``.tile`` mode. """ - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - def __post_init__(self): # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: + arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: raise OpError( self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) @@ -386,8 +373,9 @@ class CopyBulkTensorTileS2GTrait(Trait): ) return exec_value + @dataclass(frozen=True) -class CopyReduceBulkTensorTileS2GOp(CopyOp): +class CopyReduceBulkTensorTileS2GOp(TmaCopyOp): """ Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit. @@ -397,20 +385,13 @@ class CopyReduceBulkTensorTileS2GOp(CopyOp): reduction_kind: ReductionOp = ReductionOp.ADD - admissible_archs = [ - "sm_90", - "sm_90a", - "sm_100a", - "sm_100f", - ] - def __post__init__(self): # Arch verification - arch = CuTeDSL.__get_dsl().envar.arch - if arch not in self.admissible_archs: + arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: raise OpError( self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) @@ -461,11 +442,200 @@ class CopyReduceBulkTensorTileS2GTrait(Trait): ) return exec_value -__all__ = [ - "LoadCacheMode", - "CopyG2SOp", - "CopyBulkTensorTileG2SOp", - "CopyBulkTensorTileG2SMulticastOp", - "CopyBulkTensorTileS2GOp", - "CopyReduceBulkTensorTileS2GOp", -] + +# +# Bulk GMEM -> SMEM copies +# + + +@dataclass(frozen=True) +class CopyBulkG2SOp(CopyOp): + """ + Bulk copy asynchrnous GMEM to SMEM Copy Operation. + + See the `PTX documentation `__. + """ + + def __post_init__(self) -> None: + # Arch verification + arch: Arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: + raise OpError( + self, + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM bulk copy Operation" + return res + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkG2STrait": + num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) + if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): + raise ValueError( + "expects a 'num_bits_per_copy' kw argument of type int that is positive " + f"when creating a copy Atom for {self.__class__.__name__}" + ) + ty = _cute_nvgpu_ir.CopyAtomBulkCopyG2SType.get( + copy_internal_type.mlir_type, num_bits_per_copy, False + ) + return CopyBulkG2STrait(cute.make_atom(ty, loc=loc, ip=ip)) + + +class CopyBulkG2STrait(Trait): + # We allow kw args to be dropped so that the user can write common code for non-multicast + # and multicast loads. + def unpack( + self, + *, + loc=None, + ip=None, + mbar_ptr: Optional[Pointer] = None, + **kwargs, + ): + """ + Custom implementation of unpack for bulk copy load. + + The non-multicast bulk load requires a `mbar_ptr` keyword argument to be provided when + using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error. + """ + if not isinstance(mbar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the mbar_ptr kw argument" + ) + attr_str = f"#cute_nvgpu.atom_copy_field_bulkg2s<{TMA_MBAR_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + val = _cute_nvgpu_ir.atom_set_value( + self.value, attr, mbar_ptr.value, loc=loc, ip=ip + ) + return val + + +# +# Bulk GMEM -> SMEM Multicast copies +# + + +@dataclass(frozen=True) +class CopyBulkG2SMulticastOp(CopyOp): + """ + Bulk multicast copy asynchrnous GMEM to SMEM Copy Operation. + + See the `PTX documentation `__. + """ + + def __post_init__(self) -> None: + # Arch verification + arch: Arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: + raise OpError( + self, + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM multicast bulk copy Operation" + return res + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkG2SMulticastTrait": + num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) + if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): + raise ValueError( + "expects a 'num_bits_per_copy' kw argument of type int that is positive " + f"when creating a copy Atom for {self.__class__.__name__}" + ) + ty = _cute_nvgpu_ir.CopyAtomBulkCopyG2SType.get( + copy_internal_type.mlir_type, num_bits_per_copy, True + ) + return CopyBulkG2SMulticastTrait(cute.make_atom(ty, loc=loc, ip=ip)) + + +class CopyBulkG2SMulticastTrait(Trait): + # We allow kw args to be dropped so that the user can write common code for non-multicast + # and multicast loads. + def unpack( + self, + *, + loc=None, + ip=None, + mbar_ptr: Optional[Pointer] = None, + mcast_mask: Optional[Integer] = None, + **kwargs, + ): + """ + Custom implementation of unpack for bulk copy load. + + The non-multicast bulk load requires a `mbar_ptr` keyword argument to be provided when + using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error. + """ + if not isinstance(mbar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the mbar_ptr kw argument" + ) + if not isinstance(mcast_mask, Integer): + raise ValueError( + "expects a multicast mask to be provided via the mcast_mask kw argument" + ) + attr_str = f"#cute_nvgpu.atom_copy_field_bulkg2s<{TMA_MBAR_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + val = _cute_nvgpu_ir.atom_set_value( + self.value, attr, mbar_ptr.value, loc=loc, ip=ip + ) + attr_str = f"#cute_nvgpu.atom_copy_field_bulkg2s<{TMA_MCAST_MASK_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + val = _cute_nvgpu_ir.atom_set_value( + val, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + return val + + +# +# Bulk SMEM -> GMEM copies +# + + +@dataclass(frozen=True) +class CopyBulkS2GOp(CopyOp): + """ + Bulk copy asynchrnous SMEM to GMEM Copy Operation. + + See the `PTX documentation `__. + """ + + def __post_init__(self) -> None: + # Arch verification + arch: Arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: + raise OpError( + self, + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async SMEM -> GMEM bulk copy Operation" + return res + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkS2GTrait": + num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) + if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): + raise ValueError( + "expects a 'num_bits_per_copy' kw argument of type int that is positive " + f"when creating a copy Atom for {self.__class__.__name__}" + ) + ty = _cute_nvgpu_ir.CopyAtomBulkCopyS2GType.get( + copy_internal_type.mlir_type, num_bits_per_copy, False + ) + return CopyBulkS2GTrait(cute.make_atom(ty, loc=loc, ip=ip)) + + +class CopyBulkS2GTrait(Trait): + pass diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py index f64f07f1..02fe57f8 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -16,8 +16,18 @@ from cutlass.cutlass_dsl import dsl_user_op import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir.dialects import llvm -from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta -from ... import core +from ...typing import ( + Coord, + Layout, + ComposedLayout, + Tensor, + Tiler, + Pointer, + Int16, + Numeric, + NumericMeta, +) +from ... import core, atom from .copy import ( CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, @@ -39,14 +49,14 @@ def make_tiled_tma_atom( CopyReduceBulkTensorTileS2GOp, ], gmem_tensor: Tensor, - smem_layout: Union[Layout, core.ComposedLayout], + smem_layout: Union[Layout, ComposedLayout], cta_tiler: Tiler, num_multicast: int = 1, *, internal_type: Optional[Type[Numeric]] = None, loc=None, ip=None, -) -> Tuple[core.CopyAtom, Tensor]: +) -> Tuple[atom.CopyAtom, Tensor]: """ Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM buffer with the given Layout. @@ -74,7 +84,7 @@ def make_tiled_tma_atom( :param gmem_tensor: The GMEM tensor involved in the Copy :type gmem_tensor: Tensor :param smem_layout: The SMEM layout to construct the Copy Atom for - :type smem_layout: Union[Layout, core.ComposedLayout] + :type smem_layout: Union[Layout, ComposedLayout] :param cta_tiler: The CTA Tiler to use :type cta_tiler: Tiler :param num_multicast: The multicast factor @@ -82,7 +92,7 @@ def make_tiled_tma_atom( :param internal_type: An optional parameter for the internal data type to use when the actual data type is not supported by the TMA unit :type internal_type: Type[Numeric] :return: A Copy Atom for this Operation and the associated TMA tensor - :rtype: Tuple[core.CopyAtom, Tensor] + :rtype: Tuple[atom.CopyAtom, Tensor] """ if internal_type is not None: @@ -97,6 +107,9 @@ def make_tiled_tma_atom( ip=ip, ) + if isinstance(smem_layout, core._ComposedLayout): + smem_layout = smem_layout.value + if isinstance(op, CopyBulkTensorTileG2SOp): if num_multicast != 1: raise ValueError( @@ -113,7 +126,7 @@ def make_tiled_tma_atom( loc=loc, ip=ip, ) - return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + return atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] elif isinstance(op, CopyBulkTensorTileG2SMulticastOp): if num_multicast < 1: raise ValueError( @@ -131,7 +144,7 @@ def make_tiled_tma_atom( ip=ip, ) return ( - core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), res[1], ) elif isinstance(op, CopyBulkTensorTileS2GOp): @@ -143,7 +156,7 @@ def make_tiled_tma_atom( loc=loc, ip=ip, ) - return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1] + return atom.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1] elif isinstance(op, CopyReduceBulkTensorTileS2GOp): res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce( gmem_tensor.value, @@ -154,14 +167,14 @@ def make_tiled_tma_atom( loc=loc, ip=ip, ) - return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1] + return atom.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1] else: raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}") @dsl_user_op def tma_partition( - atom: core.CopyAtom, + atom: atom.CopyAtom, cta_coord: Coord, cta_layout: Layout, smem_tensor: Tensor, @@ -221,7 +234,7 @@ def create_tma_multicast_mask( @dsl_user_op -def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None: +def prefetch_descriptor(tma_atom: atom.CopyAtom, *, loc=None, ip=None) -> None: """ Prefetches the TMA descriptor associated with the TMA Atom. """ @@ -230,7 +243,7 @@ def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None: @dsl_user_op def copy_tensormap( - tma_atom: core.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None + tma_atom: atom.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None ) -> None: """ Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the provided @@ -248,7 +261,7 @@ def copy_tensormap( @dsl_user_op def update_tma_descriptor( - tma_atom: core.CopyAtom, + tma_atom: atom.CopyAtom, gmem_tensor: Tensor, tma_desc_ptr: Pointer, *, @@ -289,7 +302,7 @@ def fence_tma_desc_acquire( """ See the `PTX documentation `__. """ - tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() + tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) llvm.inline_asm( None, [tma_desc_ptr_i64], @@ -312,8 +325,12 @@ def cp_fence_tma_desc_release( """ See the `PTX documentation `__. """ - tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value() - tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value() + tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value( + loc=loc, ip=ip + ) + tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value( + loc=loc, ip=ip + ) llvm.inline_asm( None, [tma_desc_global_ptr_i64, tma_desc_shared_ptr_i32], @@ -339,3 +356,13 @@ def fence_tma_desc_release(*, loc=None, ip=None) -> None: is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) + + +@dsl_user_op +def group_bulk_copy_modes(src: Tensor, dst: Tensor, loc=None, ip=None) -> Tuple: + """ + Copy async bulk need group mode 0, acquiring whole tensor for bulk copy + """ + mSrc = core.group_modes(src, 0, core.rank(src)) + mDst = core.group_modes(dst, 0, core.rank(dst)) + return (mSrc, mDst) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py index 9b4aa0db..03814de2 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py @@ -15,8 +15,8 @@ from cutlass.cutlass_dsl import dsl_user_op import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from .. import core -from ..typing import Shape, Layout, Tensor, Numeric, NumericMeta +from .. import core, atom +from ..typing import Shape, Layout, ComposedLayout, Tensor, Numeric, NumericMeta from ...impl_utils import check_type_in from .cpasync.copy import ( CopyBulkTensorTileG2SOp, @@ -37,15 +37,15 @@ from .cpasync.copy import ( def make_tiled_tma_atom_A( op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], gmem_tensor: Tensor, - smem_layout: Union[Layout, core.ComposedLayout], + smem_layout: Union[Layout, ComposedLayout], mma_tiler_mnk: Shape, - tiled_mma: core.TiledMma, - cluster_shape_vmnk: Shape, + tiled_mma: atom.TiledMma, + cluster_shape_vmnk: Union[Shape, None] = None, *, internal_type: Optional[Type[Numeric]] = None, loc=None, ip=None, -) -> Tuple[core.CopyAtom, Tensor]: +) -> Tuple[atom.CopyAtom, Tensor]: """ Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation accounting for the MK projections of the TiledMMA for A tensor loads. @@ -76,18 +76,18 @@ def make_tiled_tma_atom_A( :param gmem_tensor: The GMEM tensor to be loaded by this copy atom :type gmem_tensor: Tensor :param smem_layout: Shared memory layout to load the tensor into (PDSL) - :type smem_layout: Union[Layout, core.ComposedLayout] + :type smem_layout: Union[Layout, ComposedLayout] :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions :type mma_tiler_mnk: Shape :param tiled_mma: The TiledMMA that will consume the load as operands - :type tiled_mma: core.TiledMma + :type tiled_mma: atom.TiledMma :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions :type cluster_shape_vmnk: Shape :param internal_type: An optional parameter for the internal data type to when element type does not match the copy type :type internal_type: Type[Numeric] :return: A copy atom for this operation and the associated TMA coord tensor - :rtype: Tuple[core.CopyAtom, Tensor] + :rtype: Tuple[atom.CopyAtom, Tensor] """ @@ -114,8 +114,15 @@ def make_tiled_tma_atom_A( else: assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) # multicast across the N-mode since those would share the same tile of A + if cluster_shape_vmnk is None: + raise ValueError( + "cluster_shape_vmnk must be provided for multicast A tensor loads" + ) num_multicast = core.size(cluster_shape_vmnk, mode=[2]) + if isinstance(smem_layout, core._ComposedLayout): + smem_layout = smem_layout.value + # res[0] = the IR Value for the non-executable atom instance # res[1] = the IR Value for the associated TMA tensor res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( @@ -129,11 +136,11 @@ def make_tiled_tma_atom_A( ip=ip, ) if isinstance(op, CopyBulkTensorTileG2SOp): - return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + return atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] else: assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) return ( - core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), res[1], ) @@ -142,15 +149,15 @@ def make_tiled_tma_atom_A( def make_tiled_tma_atom_B( op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], gmem_tensor: Tensor, - smem_layout: Union[Layout, core.ComposedLayout], + smem_layout: Union[Layout, ComposedLayout], mma_tiler_mnk: Shape, - tiled_mma: core.TiledMma, - cluster_shape_vmnk: Shape, + tiled_mma: atom.TiledMma, + cluster_shape_vmnk: Union[Shape, None] = None, *, internal_type: Optional[Type[Numeric]] = None, loc=None, ip=None, -) -> Tuple[core.CopyAtom, Tensor]: +) -> Tuple[atom.CopyAtom, Tensor]: """ Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation accounting for the NK projections of the TiledMMA for B tensor loads. @@ -181,7 +188,7 @@ def make_tiled_tma_atom_B( :param gmem_tensor: The GMEM tensor to be loaded by this copy atom :type gmem_tensor: Tensor :param smem_layout: Shared memory layout to load the tensor into (PDSL) - :type smem_layout: Union[Layout, core.ComposedLayout] + :type smem_layout: Union[Layout, ComposedLayout] :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions :type mma_tiler_mnk: Shape :param tiled_mma: The TiledMMA that will consume the load as operands @@ -192,7 +199,7 @@ def make_tiled_tma_atom_B( type does not match the copy type :type internal_type: Type[Numeric] :return: A Copy Atom for this Operation and the associated TMA tensor - :rtype: Tuple[core.CopyAtom, Tensor] + :rtype: Tuple[atom.CopyAtom, Tensor] """ @@ -219,8 +226,15 @@ def make_tiled_tma_atom_B( else: assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) # multicast across the M-mode since those would share the same tile of B + if cluster_shape_vmnk is None: + raise ValueError( + "cluster_shape_vmnk must be provided for multicast B tensor loads" + ) num_multicast = core.size(cluster_shape_vmnk, mode=[1]) + if isinstance(smem_layout, core._ComposedLayout): + smem_layout = smem_layout.value + # res[0] = the IR Value for the non-executable atom instance # res[1] = the IR Value for the associated TMA tensor res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( @@ -234,11 +248,11 @@ def make_tiled_tma_atom_B( ip=ip, ) if isinstance(op, CopyBulkTensorTileG2SOp): - return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + return atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] else: assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) return ( - core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), res[1], ) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py index 2831bec6..56c92393 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py @@ -40,6 +40,7 @@ __all__ = [ "Field", "MmaTF32Op", "MmaF16BF16Op", + "MmaF16BF16SparseOp", "MmaI8Op", "MmaFP8Op", "MmaMXF8Op", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py index df954b09..aa0da896 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -13,14 +13,15 @@ import enum from dataclasses import dataclass from typing import Type +from cutlass import cute +from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import CuTeDSL -import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError -from ...core import CopyOp, Trait +from ...atom import CopyOp, Trait from ...typing import Numeric from .mma import CtaGroup @@ -46,24 +47,6 @@ class Repetition(enum.Enum): def __repr__(self) -> str: return f"<{self.__class__.__name__}.{self.name}>" - @classmethod - def _missing_(cls, value): - if isinstance(value, int): - if value == 1: - return Repetition.x1 - elif value == 2: - return Repetition.x2 - elif value == 8: - return Repetition.x8 - elif value == 16: - return Repetition.x16 - elif value == 32: - return Repetition.x32 - elif value == 64: - return Repetition.x64 - elif value == 128: - return Repetition.x128 - class Pack(enum.Enum): """ @@ -97,17 +80,40 @@ class Unpack(enum.Enum): @dataclass(frozen=True) class _LdBase(CopyOp): + """ + Base class for TMEM load operations in the tcgen05 instruction set. + + This abstract base class provides common functionality and validation for tensor memory (TMEM) + load operations. It defines the fundamental parameters and architecture constraints that apply + to all load operation variants. + + :param repeat: Number of repetitions for the load operation, defaults to Repetition.x1 + :type repeat: Repetition, optional + :param pack: Packing pattern for TMEM to RMEM copies, defaults to Pack.NONE + :type pack: Pack, optional + :raises OpError: If the current architecture is not supported or if invalid parameters are provided + """ + repeat: Repetition = Repetition.x1 pack: Pack = Pack.NONE - admissible_archs = [ - "sm_100a", - "sm_100f", - ] + admissible_archs = Arch.filter( + lambda arch: arch.is_family_of(Arch.sm_100f) or arch.is_family_of(Arch.sm_110f) + ) def __post_init__(self) -> None: + """ + Post-initialization validation for TMEM load operations. + + Performs comprehensive validation of operation parameters and architecture compatibility. + This method is automatically called after object creation to ensure all constraints are met. + + :raises OpError: If architecture is not supported + :raises OpError: If repeat parameter is not a Repetition instance + :raises OpError: If pack parameter is not a Pack instance + """ # Arch verification - arch = CuTeDSL._get_dsl().envar.arch + arch = CuTeDSL._get_dsl().get_arch_enum() if arch not in self.admissible_archs: raise OpError( self, @@ -127,12 +133,21 @@ class _LdBase(CopyOp): ) def __str__(self) -> str: + """ + Generate a human-readable string representation of the load operation. + + Creates a formatted description showing the operation type, repetition count, + and any special packing configuration. + + :return: Multi-line string describing the operation configuration + :rtype: str + """ res = ( f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" + f"\n number of repetitions = {self.repeat.value}" ) if self.pack == Pack.PACK_16b_IN_32b: - res += f"\n with 2x 16-bit to 32b packing" + res += "\n with 2x 16-bit to 32b packing" return res @@ -148,6 +163,24 @@ class Ld16x64bOp(_LdBase): def _make_trait( self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "Ld16x64bTrait": + """ + Create a trait object for the 16x64b TMEM load operation. + + Constructs an MLIR-based trait that encapsulates the specific parameters and + characteristics of this load operation. The trait is used by the compiler + infrastructure to generate the appropriate low-level code. + + :param copy_internal_type: The data type for the copy operation + :type copy_internal_type: Type[Numeric] + :param loc: MLIR location information for debugging, defaults to None + :type loc: optional + :param ip: MLIR insertion point for code generation, defaults to None + :type ip: optional + :param kwargs: Additional keyword arguments passed to the trait constructor + :type kwargs: dict + :return: A trait object that represents this specific load operation + :rtype: Ld16x64bTrait + """ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( copy_internal_type.mlir_type, 16, @@ -155,7 +188,7 @@ class Ld16x64bOp(_LdBase): self.repeat.value, ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, ) - return Ld16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Ld16x64bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Ld16x64bTrait(Trait): @@ -172,6 +205,15 @@ class Ld16x128bOp(_LdBase): """ def __post_init__(self) -> None: + """ + Additional validation specific to 16x128b load operations. + + Extends the base class validation with operation-specific constraints. + The 16x128b operation has limitations on the maximum repetition count due to + hardware register and bandwidth constraints. + + :raises OpError: If x128 repetition is specified + """ super().__post_init__() if self.repeat == Repetition.x128: raise OpError( @@ -183,6 +225,20 @@ class Ld16x128bOp(_LdBase): def _make_trait( self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "Ld16x128bTrait": + """ + Create a trait object for the 16x128b TMEM load operation. + + :param copy_internal_type: The data type for the copy operation + :type copy_internal_type: Type[Numeric] + :param loc: MLIR location information for debugging, defaults to None + :type loc: optional + :param ip: MLIR insertion point for code generation, defaults to None + :type ip: optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: A trait object for this load operation + :rtype: Ld16x128bTrait + """ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( copy_internal_type.mlir_type, 16, @@ -190,7 +246,7 @@ class Ld16x128bOp(_LdBase): self.repeat.value, ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, ) - return Ld16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Ld16x128bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Ld16x128bTrait(Trait): @@ -207,6 +263,15 @@ class Ld16x256bOp(_LdBase): """ def __post_init__(self) -> None: + """ + Additional validation specific to 16x256b load operations. + + Extends the base class validation with operation-specific constraints. + The 16x256b operation has more restrictive limitations on repetition count due to + the larger data size per operation requiring more hardware resources. + + :raises OpError: If x64 or x128 repetition is specified + """ super().__post_init__() if self.repeat in (Repetition.x128, Repetition.x64): raise OpError( @@ -218,6 +283,20 @@ class Ld16x256bOp(_LdBase): def _make_trait( self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "Ld16x256bTrait": + """ + Create a trait object for the 16x256b TMEM load operation. + + :param copy_internal_type: The data type for the copy operation + :type copy_internal_type: Type[Numeric] + :param loc: MLIR location information for debugging, defaults to None + :type loc: optional + :param ip: MLIR insertion point for code generation, defaults to None + :type ip: optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: A trait object for this load operation + :rtype: Ld16x256bTrait + """ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( copy_internal_type.mlir_type, 16, @@ -225,7 +304,7 @@ class Ld16x256bOp(_LdBase): self.repeat.value, ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, ) - return Ld16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Ld16x256bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Ld16x256bTrait(Trait): @@ -244,6 +323,20 @@ class Ld16x32bx2Op(_LdBase): def _make_trait( self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "Ld16x32bx2Trait": + """ + Create a trait object for the 16x32bx2 TMEM load operation. + + :param copy_internal_type: The data type for the copy operation + :type copy_internal_type: Type[Numeric] + :param loc: MLIR location information for debugging, defaults to None + :type loc: optional + :param ip: MLIR insertion point for code generation, defaults to None + :type ip: optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: A trait object for this load operation + :rtype: Ld16x32bx2Trait + """ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( copy_internal_type.mlir_type, 16, @@ -251,7 +344,7 @@ class Ld16x32bx2Op(_LdBase): self.repeat.value, ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, ) - return Ld16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Ld16x32bx2Trait(cute.make_atom(ty, loc=loc, ip=ip)) class Ld16x32bx2Trait(Trait): @@ -270,6 +363,20 @@ class Ld32x32bOp(_LdBase): def _make_trait( self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "Ld32x32bTrait": + """ + Create a trait object for the 32x32b TMEM load operation. + + :param copy_internal_type: The data type for the copy operation + :type copy_internal_type: Type[Numeric] + :param loc: MLIR location information for debugging, defaults to None + :type loc: optional + :param ip: MLIR insertion point for code generation, defaults to None + :type ip: optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: A trait object for this load operation + :rtype: Ld32x32bTrait + """ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get( copy_internal_type.mlir_type, 32, @@ -277,7 +384,7 @@ class Ld32x32bOp(_LdBase): self.repeat.value, ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None, ) - return Ld32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Ld32x32bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Ld32x32bTrait(Trait): @@ -286,17 +393,30 @@ class Ld32x32bTrait(Trait): @dataclass(frozen=True) class _StBase(CopyOp): + """ + Base class for TMEM store operations in the tcgen05 instruction set. + + This abstract base class provides common functionality and validation for tensor memory (TMEM) + store operations. It defines the fundamental parameters and architecture constraints that apply + to all store operation variants. + + :param repeat: Number of repetitions for the store operation (required parameter) + :type repeat: Repetition + :param unpack: Unpacking pattern for RMEM to TMEM copies, defaults to Unpack.NONE + :type unpack: Unpack, optional + :raises OpError: If the current architecture is not supported or if invalid parameters are provided + """ + repeat: Repetition unpack: Unpack = Unpack.NONE - admissible_archs = [ - "sm_100a", - "sm_100f", - ] + admissible_archs = Arch.filter( + lambda arch: arch.is_family_of(Arch.sm_100f) or arch.is_family_of(Arch.sm_110f) + ) def __post_init__(self) -> None: # Arch verification - arch = CuTeDSL._get_dsl().envar.arch + arch = CuTeDSL._get_dsl().get_arch_enum() if arch not in self.admissible_archs: raise OpError( self, @@ -312,7 +432,7 @@ class _StBase(CopyOp): if not isinstance(self.unpack, Unpack): raise OpError( self, - "expects the 'pack' Op parameter to be a tcgen05.Unpack instance", + "expects the 'unpack' Op parameter to be a tcgen05.Unpack instance", ) def __str__(self) -> str: @@ -321,7 +441,7 @@ class _StBase(CopyOp): + f"\n number of repetitions = {self.repeat.value}" ) if self.unpack == Unpack.UNPACK_32b_IN_16b: - res += f"\n with 32-bit to 2x 16b unpacking" + res += "\n with 32-bit to 2x 16b unpacking" return res @@ -337,6 +457,20 @@ class St16x64bOp(_StBase): def _make_trait( self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "St16x64bTrait": + """ + Create a trait object for the 16x64b TMEM store operation. + + :param copy_internal_type: The data type for the copy operation + :type copy_internal_type: Type[Numeric] + :param loc: MLIR location information for debugging, defaults to None + :type loc: optional + :param ip: MLIR insertion point for code generation, defaults to None + :type ip: optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: A trait object for this store operation + :rtype: St16x64bTrait + """ ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( copy_internal_type.mlir_type, 16, @@ -344,7 +478,7 @@ class St16x64bOp(_StBase): self.repeat.value, ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, ) - return St16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return St16x64bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class St16x64bTrait(Trait): @@ -379,7 +513,7 @@ class St16x128bOp(_StBase): self.repeat.value, ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, ) - return St16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return St16x128bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class St16x128bTrait(Trait): @@ -414,7 +548,7 @@ class St16x256bOp(_StBase): self.repeat.value, ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, ) - return St16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return St16x256bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class St16x256bTrait(Trait): @@ -440,7 +574,7 @@ class St16x32bx2Op(_StBase): self.repeat.value, ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, ) - return St16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return St16x32bx2Trait(cute.make_atom(ty, loc=loc, ip=ip)) class St16x32bx2Trait(Trait): @@ -466,7 +600,7 @@ class St32x32bOp(_StBase): self.repeat.value, ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, ) - return St32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return St32x32bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class St32x32bTrait(Trait): @@ -475,20 +609,27 @@ class St32x32bTrait(Trait): @dataclass(frozen=True) class _S2TCopyBase(CopyOp): - cta_group: CtaGroup + """ + Base class for SMEM to TMEM copy operations in the tcgen05 instruction set. - admissible_archs = [ - "sm_100a", - "sm_100f", - ] + This abstract base class provides common functionality and validation for shared memory (SMEM) + to tensor memory (TMEM) copy operations. These operations are used for high-throughput data + movement between different memory hierarchies in modern GPU architectures. + + :param cta_group: Cooperative Thread Array (CTA) group configuration + :type cta_group: CtaGroup + :raises OpError: If the current architecture is not SM100f family or if invalid parameters are provided + """ + + cta_group: CtaGroup def __post_init__(self) -> None: # Arch verification - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: + arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch.is_family_of(Arch.sm_100f): raise OpError( self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", + f"expects arch to be one of {Arch.filter(lambda arch: arch.is_family_of(Arch.sm_100f))}, but got {arch}", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) # Verify that the user provided enum values @@ -519,6 +660,20 @@ class Cp128x256bOp(_S2TCopyBase): def _make_trait( self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs ) -> "Cp128x256bTrait": + """ + Create a trait object for the 128x256b SMEM to TMEM copy operation. + + :param copy_internal_type: The data type for the copy operation + :type copy_internal_type: Type[Numeric] + :param loc: MLIR location information for debugging, defaults to None + :type loc: optional + :param ip: MLIR insertion point for code generation, defaults to None + :type ip: optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: A trait object for this S2T copy operation + :rtype: Cp128x256bTrait + """ ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( copy_internal_type.mlir_type, 128, @@ -526,7 +681,7 @@ class Cp128x256bOp(_S2TCopyBase): self.cta_group.value, _cute_nvgpu_ir.CopyS2TBroadcast.none, ) - return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Cp128x256bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Cp128x256bTrait(Trait): @@ -552,7 +707,7 @@ class Cp128x128bOp(_S2TCopyBase): self.cta_group.value, _cute_nvgpu_ir.CopyS2TBroadcast.none, ) - return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Cp128x128bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Cp128x128bTrait(Trait): @@ -578,7 +733,7 @@ class Cp4x256bOp(_S2TCopyBase): self.cta_group.value, _cute_nvgpu_ir.CopyS2TBroadcast.none, ) - return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Cp4x256bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Cp4x256bTrait(Trait): @@ -604,7 +759,7 @@ class Cp4x32x128bOp(_S2TCopyBase): self.cta_group.value, _cute_nvgpu_ir.CopyS2TBroadcast.x4, ) - return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Cp4x32x128bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class Cp4x32x128bTrait(Trait): @@ -630,7 +785,7 @@ class Cp2x64x128b0213Op(_S2TCopyBase): self.cta_group.value, _cute_nvgpu_ir.CopyS2TBroadcast.lw_0213, ) - return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Cp2x64x128b0213Trait(cute.make_atom(ty, loc=loc, ip=ip)) class Cp2x64x128b0213Trait(Trait): @@ -656,7 +811,7 @@ class Cp2x64x128b0123Op(_S2TCopyBase): self.cta_group.value, _cute_nvgpu_ir.CopyS2TBroadcast.lw_0123, ) - return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return Cp2x64x128b0123Trait(cute.make_atom(ty, loc=loc, ip=ip)) class Cp2x64x128b0123Trait(Trait): diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py index 0ad27e62..f5c2cef8 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py @@ -13,7 +13,6 @@ from typing import overload, Type, Tuple, Union from cutlass.cutlass_dsl import dsl_user_op -import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir.dialects import nvvm @@ -21,6 +20,7 @@ from ...typing import ( Shape, IntTuple, Layout, + ComposedLayout, Tensor, Int, Numeric, @@ -29,6 +29,8 @@ from ...typing import ( Int32, ) from ... import core +from ...tensor import recast_tensor +from ...atom import CopyAtom, TiledCopy from .mma import SmemLayoutAtomKind, CtaGroup from .copy import ( Pack, @@ -56,7 +58,7 @@ from .copy import ( @dsl_user_op def make_smem_layout_atom( kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None -) -> core.ComposedLayout: +) -> ComposedLayout: """ Makes a SMEM layout Atom. @@ -68,7 +70,7 @@ def make_smem_layout_atom( :param element_type: The element data type to construct the layout for :type element_type: Type[Numeric] :return: The SMEM layout atom - :rtype: core.ComposedLayout + :rtype: ComposedLayout """ if not isinstance(element_type, NumericMeta): raise TypeError(f"element_type must be a Numeric, but got {element_type}") @@ -130,13 +132,13 @@ def tile_to_mma_shape( @overload def tile_to_mma_shape( - atom: core.ComposedLayout, + atom: ComposedLayout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None, -) -> core.ComposedLayout: ... +) -> ComposedLayout: ... @dsl_user_op @@ -152,7 +154,7 @@ def tile_to_mma_shape( if core.rank(order) != core.rank(mma_tile_shape) - 1: raise ValueError( f"rank(order)={core.rank(order)} must be equal to " - f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape)-1}" + f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape) - 1}" ) order_val = core._pack_int_tuple(order, loc=loc, ip=ip) mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip) @@ -164,8 +166,12 @@ def tile_to_mma_shape( ): raise ValueError("tile_to_mma_shape only supports static inputs") + if isinstance(atom, core._ComposedLayout): + atom = atom.value + res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val) - return _cute_ir.static(res_ty, loc=loc, ip=ip) + res_val = core.static(res_ty, loc=loc, ip=ip) + return core.coalesce(res_val, target_profile=mma_tile_shape, loc=loc, ip=ip) @dsl_user_op @@ -209,7 +215,7 @@ def commit( #################################################################################################### -def is_tmem_load(atom: core.CopyAtom) -> bool: +def is_tmem_load(atom: CopyAtom) -> bool: """ Returns whether a CopyAtom instance is a TMEM load. """ @@ -225,7 +231,7 @@ def is_tmem_load(atom: core.CopyAtom) -> bool: ) -def is_tmem_store(atom: core.CopyAtom) -> bool: +def is_tmem_store(atom: CopyAtom) -> bool: """ Returns whether a CopyAtom instance is a TMEM store. """ @@ -242,7 +248,7 @@ def is_tmem_store(atom: core.CopyAtom) -> bool: def get_tmem_copy_properties( - atom: core.CopyAtom, + atom: CopyAtom, ) -> Tuple[int, int, int, Union[Pack, Unpack]]: """ Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions, @@ -279,7 +285,7 @@ def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> In """ tmem_col_mask = 0x0000FFFF offset = ( - core.cosize(core.recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip) + core.cosize(recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip) & tmem_col_mask ) if isinstance(offset, int): @@ -289,8 +295,8 @@ def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> In @dsl_user_op def make_tmem_copy( - atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None -) -> core.TiledCopy: + atom: CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None +) -> TiledCopy: """ Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. """ @@ -298,13 +304,13 @@ def make_tmem_copy( atom._trait.value, tmem_tensor.value, loc=loc, ip=ip ) new_trait = type(atom._trait)(tiled_copy_val) - return core.TiledCopy(atom.op, new_trait) + return TiledCopy(atom.op, new_trait) @dsl_user_op def make_s2t_copy( - atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None -) -> core.TiledCopy: + atom: CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None +) -> TiledCopy: """ Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. """ @@ -312,12 +318,12 @@ def make_s2t_copy( atom._trait.value, tmem_tensor.value, loc=loc, ip=ip ) new_trait = type(atom._trait)(tiled_copy_val) - return core.TiledCopy(atom.op, new_trait) + return TiledCopy(atom.op, new_trait) @dsl_user_op def get_s2t_smem_desc_tensor( - atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None + atom: CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None ) -> Tensor: """ Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor. diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py index 3a938523..2a116cb0 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -11,8 +11,10 @@ import enum from dataclasses import dataclass -from typing import Type +from typing import Type, Any +from cutlass import cute +from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import CuTeDSL, T import cutlass._mlir.dialects.cute as _cute_ir @@ -20,8 +22,9 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError -from ... import core -from ...core import Trait, _pack_shape, rank, depth, _Tensor +from ... import core, atom +from ...core import _pack_shape, rank, depth +from ...tensor import _Tensor from ...typing import ( Shape, Float4E2M1FN, @@ -40,6 +43,9 @@ from ...typing import ( AddressSpace, Pointer, ) +from ...atom import Trait + +from ..warp.mma import SparseMetadataFormat #################################################################################################### @@ -49,6 +55,14 @@ from ...typing import ( #################################################################################################### +class Tcgen05MmaOp(atom.MmaOp): + """ + Base class for all tcgen05 MMA operations. + """ + + pass + + class OperandMajorMode(enum.Enum): """ An enumeration for the majorness of the input operands of the MMA. @@ -108,6 +122,7 @@ class CtaGroup(enum.Enum): def __repr__(self) -> str: return f"<{self.__class__.__name__}.{self.name}>" + class Field(enum.Enum): """ An enumeration for the fields of the MMA Atom that can be modified at runtime. @@ -131,7 +146,7 @@ class Field(enum.Enum): # Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code @dataclass(frozen=True) -class MmaOp(core.MmaOp): +class MmaOp(Tcgen05MmaOp): a_dtype: Type[Numeric] b_dtype: Type[Numeric] acc_dtype: Type[Numeric] @@ -141,14 +156,13 @@ class MmaOp(core.MmaOp): a_major_mode: OperandMajorMode b_major_mode: OperandMajorMode - admissible_archs = [ - "sm_100a", - "sm_100f", - ] + admissible_archs = Arch.filter( + lambda arch: arch.is_family_of(Arch.sm_100f) or arch.is_family_of(Arch.sm_110f) + ) def __post_init__(self) -> None: # Verify arch - arch = CuTeDSL._get_dsl().envar.arch + arch = CuTeDSL._get_dsl().get_arch_enum() if arch not in self.admissible_archs: raise OpError( self, @@ -194,18 +208,18 @@ class MmaOp(core.MmaOp): f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", ) elif m == 128: - if (n < 16) or (n > 256) or (n % 16 != 0): + if (n < 8) or (n > 256) or (n % 8 != 0): raise OpError( self, - f"expects the N-mode to satisfy 8 <= N <= 256 and N % 16 == 0, but got {n}", + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", ) else: if m not in [128, 256]: raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") - if (n < 32) or (n > 256) or (n % 32 != 0): + if (n < 16) or (n > 256) or (n % 16 != 0): raise OpError( self, - f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}", + f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}", ) def __str__(self) -> str: @@ -246,7 +260,7 @@ class MmaOp(core.MmaOp): return True -class MmaTrait(Trait): +class MmaTraits(Trait): admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B] def set(self, field, value, *, loc=None, ip=None) -> None: @@ -260,10 +274,21 @@ class MmaTrait(Trait): self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip ) + def get(self, field, *, loc=None, ip=None) -> Any: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + return _cute_nvgpu_ir.atom_get_value( + Boolean.mlir_type, self.value, attr, loc=loc, ip=ip + ) + # Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code @dataclass(frozen=True) -class BlockScaledMmaOp(core.MmaOp): +class BlockScaledMmaOp(Tcgen05MmaOp): a_dtype: Type[Numeric] b_dtype: Type[Numeric] acc_dtype: Float32 @@ -276,12 +301,12 @@ class BlockScaledMmaOp(core.MmaOp): b_major_mode: OperandMajorMode admissible_archs = [ - "sm_100a", + Arch.sm_100a, ] def __post_init__(self) -> None: # Verify arch - arch = CuTeDSL._get_dsl().envar.arch + arch = CuTeDSL._get_dsl().get_arch_enum() if arch not in self.admissible_archs: raise OpError( self, @@ -409,6 +434,170 @@ class BlockScaledMmaTraits(Trait): self.value, attr, value, loc=loc, ip=ip ) + def get(self, field, *, loc=None, ip=None) -> Any: + if field not in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]: + raise ValueError(f"the get method for {field} is not supported") + field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + return _cute_nvgpu_ir.atom_get_value( + Boolean.mlir_type, self.value, attr, loc=loc, ip=ip + ) + + +# Base class for all tcgen05 Sparse MMA Ops with syntax `tcgen05.mma.cta_group.kind.sparse` used to factor out some internal code +@dataclass(frozen=True) +class SparseMmaOp(Tcgen05MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + cta_group: CtaGroup + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + sparse_metadata_format: SparseMetadataFormat + + admissible_archs = Arch.filter( + lambda arch: arch.is_family_of(Arch.sm_100f) or arch.is_family_of(Arch.sm_110f) + ) + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().get_arch_enum() + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + if not isinstance(self.sparse_metadata_format, SparseMetadataFormat): + raise OpError( + self, + "expects the 'sparse_metadata_format' Op parameter to be a tcgen05.SparseMetadataFormat instance", + ) + # Verify the instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got {self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + # For sparse MMA, the shape validation follows the same rules as dense MMA + # but the K dimension is typically doubled in the derived classes + if self.cta_group == CtaGroup.ONE: + if m not in [64, 128]: + raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}") + if m == 64: + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", + ) + elif m == 128: + if (n < 16) or (n > 256) or (n % 16 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}", + ) + else: + if m not in [128, 256]: + raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") + if (n < 32) or (n > 256) or (n % 32 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n CTA group = {self.cta_group}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + + f"\n Sparse metadata format = {self.sparse_metadata_format}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class SparseMmaTraits(Trait): + admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + field_name = ( + f"#cute_nvgpu.atom_mma_field_sm100_sparse<{field._to_ir_field_name()}>" + ) + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + def get(self, field, *, loc=None, ip=None) -> Any: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + field_name = ( + f"#cute_nvgpu.atom_mma_field_sm100_sparse<{field._to_ir_field_name()}>" + ) + attr = ir.Attribute.parse(field_name) + return _cute_nvgpu_ir.atom_get_value( + Boolean.mlir_type, self.value, attr, loc=loc, ip=ip + ) + # # TF32 MMA @@ -472,18 +661,20 @@ class MmaTF32Op(MmaOp): 0, ) return MmaTF32Trait( - _cute_nvgpu_ir.make_sm100_mma( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + ), loc=loc, ip=ip, ) ) -class MmaTF32Trait(MmaTrait): +class MmaTF32Trait(MmaTraits): pass @@ -564,18 +755,123 @@ class MmaF16BF16Op(MmaOp): 0, ) return MmaF16BF16Trait( - _cute_nvgpu_ir.make_sm100_mma( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + ), loc=loc, ip=ip, ) ) -class MmaF16BF16Trait(MmaTrait): +class MmaF16BF16Trait(MmaTraits): + pass + + +@dataclass(frozen=True) +class MmaF16BF16SparseOp(SparseMmaOp): + """ + F16/BF16 tcgen05 Sparse MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::f16`` qualifier with sparse support. + """ + + descriptive_name = "tcgen05 F16/BF16 Sparse MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + sparse_metadata_format: SparseMetadataFormat, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + acc_dtype, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + sparse_metadata_format, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Instruction shape verification + instruction_k = 32 # For sparse, K is doubled compared to dense F16/BF16 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16SparseTrait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMASparseType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + T.ui8(), + self.sparse_metadata_format._to_ir(), + self.a_src._to_ir(), + 0, # cScaleExp + ) + + def get_e_ptr(): + ptr_type = _cute_ir.PtrType.get(T.ui8(), _cute_ir.AddressSpace.tmem, 8) + address_value = Int32(0).ir_value(loc=loc, ip=ip) + aligned_ty = _cute_ir.ConstrainedIntType.get(8, 32) + aligned_intptr = _cute_ir.assume(aligned_ty, address_value, loc=loc, ip=ip) + ui8_tmem_ptr = _cute_ir.inttoptr(ptr_type, aligned_intptr, loc=loc, ip=ip) + return ui8_tmem_ptr + + return MmaF16BF16SparseTrait( + cute.make_atom( + ty, + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + get_e_ptr().value, + ), + loc=loc, + ip=ip, + ) + ) + + +class MmaF16BF16SparseTrait(SparseMmaTraits): pass @@ -649,18 +945,20 @@ class MmaI8Op(MmaOp): 0, ) return MmaI8Trait( - _cute_nvgpu_ir.make_sm100_mma( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + ), loc=loc, ip=ip, ) ) -class MmaI8Trait(MmaTrait): +class MmaI8Trait(MmaTraits): pass @@ -689,7 +987,6 @@ class MmaFP8Op(MmaOp): a_major_mode: OperandMajorMode, b_major_mode: OperandMajorMode, ) -> None: - super().__init__( ab_dtype, ab_dtype, @@ -741,18 +1038,20 @@ class MmaFP8Op(MmaOp): 0, ) return MmaFP8Trait( - _cute_nvgpu_ir.make_sm100_mma( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + ), loc=loc, ip=ip, ) ) -class MmaFP8Trait(MmaTrait): +class MmaFP8Trait(MmaTraits): pass @@ -829,13 +1128,19 @@ class MmaMXF8Op(BlockScaledMmaOp): self.sf_vec_size, ) return MmaMXF8Trait( - _cute_nvgpu_ir.make_sm100_mma_bs( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + ), loc=loc, ip=ip, ) @@ -909,13 +1214,19 @@ class MmaMXF4Op(BlockScaledMmaOp): self.sf_vec_size, ) return MmaMXF4Trait( - _cute_nvgpu_ir.make_sm100_mma_bs( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + ), loc=loc, ip=ip, ) @@ -996,13 +1307,19 @@ class MmaMXF4NVF4Op(BlockScaledMmaOp): self.sf_vec_size, ) return MmaMXF4NVF4Trait( - _cute_nvgpu_ir.make_sm100_mma_bs( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - Boolean(False).ir_value(loc=loc, ip=ip), - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, - core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + ( + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + ), loc=loc, ip=ip, ) @@ -1012,6 +1329,7 @@ class MmaMXF4NVF4Op(BlockScaledMmaOp): class MmaMXF4NVF4Trait(BlockScaledMmaTraits): pass + #################################################################################################### # # SMEM layout atoms diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py index a6ad4ca8..2307fcae 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py @@ -12,13 +12,14 @@ from dataclasses import dataclass from typing import Type -import cutlass._mlir.dialects.cute as _cute_ir +from cutlass import cute import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError -from ...core import CopyOp, Trait, _pack_shape +from ...core import _pack_shape from ...typing import Numeric +from ...atom import CopyOp, Trait @dataclass(frozen=True) @@ -39,7 +40,7 @@ class BaseOp(CopyOp): + f"\n number of matrices = {self.num_matrices}" ) if self.transpose: - res += f"\n transposed" + res += "\n transposed" return res @@ -71,7 +72,7 @@ class LdMatrix8x8x16bOp(BaseOp): self.num_matrices, ir.UnitAttr.get() if self.transpose else None, ) - return LdMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return LdMatrix8x8x16bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class LdMatrix8x8x16bTrait(Trait): @@ -110,7 +111,7 @@ class LdMatrix16x16x8bOp(BaseOp): self.num_matrices, ir.UnitAttr.get(), ) - return LdMatrix16x16x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return LdMatrix16x16x8bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class LdMatrix16x16x8bTrait(Trait): @@ -144,7 +145,7 @@ class StMatrix8x8x16bOp(BaseOp): self.num_matrices, ir.UnitAttr.get() if self.transpose else None, ) - return StMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return StMatrix8x8x16bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class StMatrix8x8x16bTrait(Trait): @@ -182,7 +183,7 @@ class StMatrix16x8x8bOp(BaseOp): self.num_matrices, ir.UnitAttr.get(), ) - return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return StMatrix16x8x8bTrait(cute.make_atom(ty, loc=loc, ip=ip)) class StMatrix16x8x8bTrait(Trait): diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py index 49df213b..968e0af4 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -12,16 +12,36 @@ from dataclasses import dataclass from typing import Type -import cutlass._mlir.dialects.cute as _cute_ir -import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +import enum +from cutlass import cute from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape, _Tensor -from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace +from ...typing import Shape, Float16, BFloat16, Float32, Numeric +from ...core import _pack_shape +from ...tensor import _Tensor +from ...atom import MmaOp, Trait + +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects.cute_nvgpu import SparseMetadataFormat + + +#################################################################################################### +# +# MMA Ops and Traits +# +#################################################################################################### + + +class WarpMmaOp(MmaOp): + """ + Base class for all warp-level MMA operations. + """ + + pass @dataclass(frozen=True) -class MmaF16BF16Op(MmaOp): +class MmaF16BF16Op(WarpMmaOp): """ F16/BF16 tcgen05 MMA Operation. @@ -63,7 +83,7 @@ class MmaF16BF16Op(MmaOp): self.ab_dtype.mlir_type, self.acc_dtype.mlir_type, ) - return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + return MmaF16BF16Trait(cute.make_atom(ty, loc=loc, ip=ip)) def __str__(self) -> str: return ( @@ -79,5 +99,84 @@ class MmaF16BF16Op(MmaOp): def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): pass + class MmaF16BF16Trait(Trait): pass + + +class SparseMetadataFormat(enum.Enum): + """ + An enumeration for the sparse metadata format of the MMA. + """ + + TID = SparseMetadataFormat.tid + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_nvgpu_ir.SparseMetadataFormat: + return self.value + + +@dataclass(frozen=True) +class MmaF16BF16SparseOp(WarpMmaOp): + ab_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + sparse_metadata_format: SparseMetadataFormat + + def __post_init__(self) -> None: + # verify field after initialization + if not isinstance(self.sparse_metadata_format, SparseMetadataFormat): + raise OpError( + self, + "expects the 'sparse_metadata_format' Op parameter to be a SparseMetadataFormat instance", + ) + # verify the instruction shape + if self.ab_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32): + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16", + ) + if self.shape_mnk not in [(16, 8, 16), (16, 8, 32)]: + raise OpError( + self, + "expects the 'shape_mnk' Op parameter to be one of (16,8,16) or (16,8,32)", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16SparseTrait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM80SparseType.get( + shape_mnk.type.attribute, + self.ab_dtype.mlir_type, + self.ab_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sparse_metadata_format._to_ir(), + ) + return MmaF16BF16SparseTrait(cute.make_atom(ty, loc=loc, ip=ip)) + + def __str__(self) -> str: + return ( + "warp-level F16/BF16 Sparse MMA Operation" + + f"\n A/B data type = {self.ab_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + + f"\n Sparse metadata format = {self.sparse_metadata_format}" + ) + + +class MmaF16BF16SparseTrait(Trait): + pass diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py index f6284134..1a7130bd 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py @@ -15,7 +15,7 @@ from cutlass.cutlass_dsl import dsl_user_op from cutlass._mlir.dialects import nvvm -from ...typing import Numeric, NumericMeta +from ...typing import Numeric, NumericMeta, ComposedLayout from ... import core from .mma import SmemLayoutAtomKind @@ -23,7 +23,7 @@ from .mma import SmemLayoutAtomKind @dsl_user_op def make_smem_layout_atom( kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None -) -> core.ComposedLayout: +) -> ComposedLayout: """ Makes a SMEM layout Atom. @@ -35,7 +35,7 @@ def make_smem_layout_atom( :param element_type: The element data type to construct the layout for :type element_type: Type[Numeric] :return: The SMEM layout atom - :rtype: core.ComposedLayout + :rtype: ComposedLayout """ if not isinstance(element_type, NumericMeta): raise TypeError(f"element_type must be a Numeric, but got {element_type}") diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py index 275861f7..ef282047 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py @@ -11,16 +11,19 @@ import enum from dataclasses import dataclass -from typing import Type +from typing import Type, Any -from cutlass.cutlass_dsl import CuTeDSL +from cutlass import cute +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import CuTeDSL, T import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor +from ...core import _pack_shape, rank, depth +from ...tensor import _Tensor from ...typing import ( Shape, Float16, @@ -29,9 +32,13 @@ from ...typing import ( Boolean, Float8E5M2, Float8E4M3FN, + Int32, + Int8, + Uint8, Numeric, AddressSpace, ) +from ...atom import MmaOp, Trait #################################################################################################### @@ -41,6 +48,14 @@ from ...typing import ( #################################################################################################### +class WarpGroupMmaOp(MmaOp): + """ + Base class for all warpgroup-level MMA operations. + """ + + pass + + class OperandMajorMode(enum.Enum): """ An enumeration for the majorness of the input operands of the MMA. @@ -104,7 +119,7 @@ class Field(enum.Enum): @dataclass(frozen=True) -class MmaOp(MmaOp): +class MmaOp(WarpGroupMmaOp): a_dtype: Type[Numeric] b_dtype: Type[Numeric] acc_dtype: Type[Numeric] @@ -113,15 +128,13 @@ class MmaOp(MmaOp): a_major_mode: OperandMajorMode b_major_mode: OperandMajorMode - admissible_archs = ["sm_90a"] - def __post_init__(self) -> None: # Verify arch - arch = CuTeDSL._get_dsl().envar.arch - if arch not in self.admissible_archs: + arch = CuTeDSL._get_dsl().get_arch_enum() + if not arch == Arch.sm_90a: raise OpError( self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", + f"expects arch to be {Arch.sm_90a}, but got {arch}", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) # Verify that the user provided enum values @@ -193,7 +206,7 @@ class MmaOp(MmaOp): return True -class MmaTrait(Trait): +class MmaTraits(Trait): admissible_fields = [Field.ACCUMULATE] def set(self, field, value, *, loc=None, ip=None) -> None: @@ -207,13 +220,24 @@ class MmaTrait(Trait): self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip ) + def get(self, field, *, loc=None, ip=None) -> Any: + if field not in self.admissible_fields: + raise ValueError( + f"invalid field, must be {Field.ACCUMULATE}, but got {field}" + ) + field_name = f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + return _cute_nvgpu_ir.atom_get_value( + Boolean.mlir_type, self.value, attr, loc=loc, ip=ip + ) + @dataclass(frozen=True) class MmaF16BF16Op(MmaOp): """ F16/BF16 warpgroup MMA Operation. - See the `PTX documentation `__. + See the `PTX documentation `__. This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands. """ @@ -281,16 +305,16 @@ class MmaF16BF16Op(MmaOp): self.a_src._to_ir(), ) return MmaF16BF16Trait( - _cute_nvgpu_ir.make_sm90_mma( + cute.make_atom( ty, - Boolean(False).ir_value(loc=loc, ip=ip), + (Boolean(False).ir_value(loc=loc, ip=ip),), loc=loc, ip=ip, ) ) -class MmaF16BF16Trait(MmaTrait): +class MmaF16BF16Trait(MmaTraits): pass @@ -299,7 +323,7 @@ class MmaF8Op(MmaOp): """ F16/BF16 warpgroup MMA Operation. - See the `PTX documentation `__. + See the `PTX documentation `__. This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands. """ @@ -367,13 +391,111 @@ class MmaF8Op(MmaOp): self.a_src._to_ir(), ) return MmaF8Trait( - _cute_nvgpu_ir.make_sm90_mma( - ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + cute.make_atom( + ty, + (Boolean(False).ir_value(loc=loc, ip=ip),), + loc=loc, + ip=ip, ) ) -class MmaF8Trait(MmaTrait): +class MmaF8Trait(MmaTraits): + pass + + +@dataclass(frozen=True) +class MmaI8Op(MmaOp): + """ + I8 warpgroup MMA Operation. + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.s8`` or ``.u8`` qualifiers for the input operands. + """ + + descriptive_name = "warpgroup I8 MMA Operation" + + def __init__( + self, + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + a_dtype, + b_dtype, + acc_dtype, + instruction_shape, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self): + # Input data type verification + if self.a_dtype not in [Int8, Uint8]: + raise OpError( + self, + "expects the 'a_dtype' Op parameter to be one of Int8 or Uint8", + ) + if self.b_dtype not in [Int8, Uint8]: + raise OpError( + self, + "expects the 'b_dtype' Op parameter to be one of Int8 or Uint8", + ) + # Accumulator data type verification + if self.acc_dtype != Int32: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter must be Int32", + ) + + # Verify the instruction shape + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + n = self.shape_mnk[1] + if not (n >= 8 and n <= 256 and (n == 8 or n == 24 or n % 16 == 0)): + raise OpError( + self, + "expects the N-mode to satisfy N=8*i where i={1,2,3,4} ", + f"or N=16*i where i={{3,4,...,15,16}}. But got {n}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( + shape_mnk.type.attribute, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + (T.si8() if self.a_dtype.signed else T.ui8()), + (T.si8() if self.b_dtype.signed else T.ui8()), + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + ) + return MmaI8Trait( + cute.make_atom( + ty, + (Boolean(False).ir_value(loc=loc, ip=ip),), + loc=loc, + ip=ip, + ) + ) + + +class MmaI8Trait(MmaTraits): pass diff --git a/python/CuTeDSL/cutlass/cute/runtime.py b/python/CuTeDSL/cutlass/cute/runtime.py index 9128c67a..4395d7a1 100644 --- a/python/CuTeDSL/cutlass/cute/runtime.py +++ b/python/CuTeDSL/cutlass/cute/runtime.py @@ -13,41 +13,18 @@ import ctypes from functools import lru_cache import itertools import operator -from time import time -from typing import Union +from typing import Union, Optional # MLIR modules imports from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir -from cutlass.base_dsl.dsl import is_dynamic_expression -from cutlass.cutlass_dsl import JitArgAdapterRegistry +from cutlass.cutlass_dsl import JitArgAdapterRegistry, DSLRuntimeError # Local modules imports -from .typing import ( - AddressSpace, - Tensor, - Type, - Pointer, - Boolean, - Numeric, - Float4E2M1FN, - Int64, - Int32, - Int16, - Int8, - Uint64, - Uint32, - Uint16, - Uint8, - Float64, - Float32, - Float16, - BFloat16, - Float8E5M2, -) +from .typing import AddressSpace, Tensor, Type, Pointer, Numeric from . import core -from .core import _Tensor as CoreTensor +from .tensor import _Tensor as CoreTensor class _Pointer(Pointer): @@ -88,9 +65,9 @@ class _Pointer(Pointer): self._assumed_align = assumed_align self._c_pointer = None - assert ( - int(self._pointer) % self._assumed_align == 0 - ), f"pointer must be {self._assumed_align} bytes aligned" + assert int(self._pointer) % self._assumed_align == 0, ( + f"pointer must be {self._assumed_align} bytes aligned" + ) def size_in_bytes(self) -> int: self._desc = ctypes.c_void_p(int(self._pointer)) @@ -109,9 +86,6 @@ class _Pointer(Pointer): assert len(values) == 1 return values[0] - def __extract_mlir_values__(self): - return [self._c_pointer] - # Move mlir Type out of __init__ to decouple with mlir Context @property def mlir_type(self) -> ir.Type: @@ -146,11 +120,7 @@ class _Pointer(Pointer): class _Tensor(Tensor): - def __init__( - self, - tensor, - assumed_align=None, - ): + def __init__(self, tensor, assumed_align=None, use_32bit_stride=False): # If tensor is already a DLPack object, use it directly if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"): self._dlpack_data = tensor @@ -161,13 +131,13 @@ class _Tensor(Tensor): self._is_dynamic = False self._memref_desc = None self._dtype = None + self._use_32bit_stride = use_32bit_stride @property def __class__(self) -> Type[Tensor]: # Cheat to let `type(_Tensor())` to return cute.Tensor return Tensor - @staticmethod def lazily_load_dltensor(func): """Decorator to lazily load the DLTensorWrapper. @@ -177,13 +147,15 @@ class _Tensor(Tensor): def wrapper(self, *args, **kwargs): if self._dltensor_wrapper is None: - self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data) + self._dltensor_wrapper = _cute_ir.DLTensorWrapper( + self._dlpack_data, self._use_32bit_stride + ) return func(self, *args, **kwargs) return wrapper @lazily_load_dltensor - def mark_layout_dynamic(self, leading_dim: int | None = None): + def mark_layout_dynamic(self, leading_dim: Optional[int] = None): """Marks the tensor layout as dynamic based on the leading dimension. :param leading_dim: The leading dimension of the layout, defaults to None @@ -209,7 +181,7 @@ class _Tensor(Tensor): def mark_compact_shape_dynamic( self, mode: int, - stride_order: tuple[int, ...] | None = None, + stride_order: Optional[tuple[int, ...]] = None, divisibility: int = 1, ): """Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides. @@ -308,10 +280,10 @@ class _Tensor(Tensor): return self.__str__() def __setitem__(self, crd, value): - raise TypeError(f"runtime._Tensor is not indexable") + raise TypeError("runtime._Tensor is not indexable") def __getitem__(self, crd): - raise TypeError(f"runtime._Tensor is not indexable") + raise TypeError("runtime._Tensor is not indexable") @property @lazily_load_dltensor @@ -326,7 +298,7 @@ class _Tensor(Tensor): @property def layout(self): raise NotImplementedError( - f"layout property is not supported in runtime, support in future" + "layout property is not supported in runtime, support in future" ) @property @@ -363,7 +335,7 @@ class _Tensor(Tensor): return core.leading_dim(self.shape, self.stride) def fill(self, value: Numeric): - raise TypeError(f"fill function is not supported in runtime") + raise TypeError("fill function is not supported in runtime") @property @lazily_load_dltensor @@ -389,6 +361,7 @@ class _Tensor(Tensor): def from_dlpack( tensor_dlpack, assumed_align=None, + use_32bit_stride=False, ) -> Tensor: """Convert from tensor object supporting __dlpack__() to a CuTe Tensor. @@ -397,6 +370,10 @@ def from_dlpack( :param assumed_align: Assumed alignment of the tensor (bytes), defaults to None, if None, will use the element size bytes as the assumed alignment. :type assumed_align: int, optional + :param use_32bit_stride: Whether to use 32-bit stride, defaults to False. When True, the dynamic + stride bitwidth will be set to 32 for small problem size (cosize(layout) <= Int32_max) for better performance. + This is only applied when the dimension is dynamic. + :type use_32bit_stride: bool, optional :return: A CuTe Tensor object :rtype: Tensor @@ -415,6 +392,7 @@ def from_dlpack( return _Tensor( tensor_dlpack, assumed_align=assumed_align, + use_32bit_stride=use_32bit_stride, ) diff --git a/python/CuTeDSL/cutlass/cute/tensor.py b/python/CuTeDSL/cutlass/cute/tensor.py new file mode 100644 index 00000000..96834f2b --- /dev/null +++ b/python/CuTeDSL/cutlass/cute/tensor.py @@ -0,0 +1,1976 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Union, Type, Tuple, overload +from typing_extensions import deprecated +from inspect import isclass +import operator + +from cutlass.cutlass_dsl import ( + dsl_user_op, + lru_cache_ir, + T, + cutlass_arith, + _binary_op_type_promote, +) +from cutlass._mlir import ir +import cutlass._mlir.dialects.cute as _cute_ir +from cutlass._mlir.dialects.cute import ReductionOp as ReductionOp +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import vector, arith + +from .core import ( + _unpack_x_tuple, + _pack_int_tuple, + _pack_coord, + _pack_shape, + _ComposedLayout, + is_static, + is_weakly_congruent, + rank, + append, + depth, + flatten, + has_underscore, + make_layout, + slice_, + crd2idx, + size, + leading_dim, + recast_ptr, + recast_layout, +) + +from .typing import ( + IntTuple, + Coord, + Shape, + Stride, + Pointer, + Layout, + ComposedLayout, + Tensor, + AddressSpace, + is_integer, + is_int_tuple, + as_numeric, +) +from .typing import ( + Numeric, + Integer, + Boolean, + Int4, + Uint8, + Int8, + Int32, + Float4E2M1FN, + Float16, + Float32, + BFloat16, +) +from .tuple import transform_leaf, product, product_like, flatten_to_tuple +from .arch import cvt_i8_bf16_intrinsic, cvt_i4_bf16_intrinsic, cvt_f4e2m1_f16_intrinsic + + +@ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True) +@ir.register_value_caster(_cute_ir.CoordTensorType.get_static_typeid(), replace=True) +@ir.register_value_caster( + _cute_nvgpu_ir.SmemDescViewType.get_static_typeid(), replace=True +) +class _Tensor(Tensor): + r"""Builtin Tensor Type as an IR value supporting standard iterator and layout.. + + :param value: The MLIR operation result value to initialize the tensor with + :type value: ir.Value + :param dtype: The user specified data type of the tensor elements. It could be \ + different from the underlying dtype in the iterator. The default is None. + :type dtype: Type[Numeric], optional + :param loc: The source location for the operation, defaults to None + :type loc: Location, optional + :param ip: The insertion point for the operation, defaults to None + :type ip: InsertionPoint, optional + + **Examples:** + + .. code-block:: python + + # Create a tensor with shape (4,8) in row-major layout + tensor = make_tensor(ptr, make_layout(shape=(4,8), stride=(8,1))) + + # Access individual element + val = tensor[0, 0] # or val = tensor[(0, 0)] + + # Slice operation - get first column + subtensor = tensor[None, 0] # or subtensor = tensor[(None, 0)] + """ + + @dsl_user_op + def __init__( + self, value, dtype: Optional[Type[Numeric]] = None, *, loc=None, ip=None + ): + self._dtype = dtype + if isinstance(value, ir.Value): + self.value = value + elif isinstance(value, _Tensor): + self.value = value.value + elif isinstance(value, _Tensor): + self.value = value.value + else: + raise TypeError(f"Expected ir.Value or _Tensor, got {type(value)}") + + # Set iterator + iter_val = _cute_ir.get_iter(self.value, loc=loc, ip=ip) + if isinstance(iter_val, Pointer): + self._iterator = iter_val + elif isinstance(iter_val.type, _cute_ir.IntTupleType): + self._iterator = _unpack_x_tuple(iter_val) + elif isinstance(iter_val, ir.Value): + # Example: SMEM descriptor iterator, not well supported today + self._iterator = iter_val + else: + raise TypeError(f"unsupported iterator type, got {type(iter_val)}") + + # Set dtype + if self._dtype is None: + if is_int_tuple(self.iterator): + self._dtype = IntTuple + elif isinstance(self.iterator, Pointer): + self._dtype = self.iterator.value_type + elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType): + # SmemDescViewType do not need dtype + self._dtype = None + else: + raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") + + def __str__(self): + from .core import pretty_str + + return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>" + + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + # Only expecting single value of _Tensor or ir.Value + # In this context, a _Tensor instance is an encapsulated ir.Value which is automatically created + # by value caster for MemRef/CoordTensor/SmemDescView typed values + assert len(values) == 1, f"Expected 1 value, but got {len(values)}" + assert isinstance(values[0], (_Tensor, ir.Value)), ( + f"Expected _Tensor or ir.Value, but got {type(values[0])}" + ) + return _Tensor( + values[0] if isinstance(values[0], ir.Value) else values[0].value, + dtype=self.element_type, + ) + + # Cheat to let `Type(_Tensor())` to return cute.Tensor + @property + def __class__(self) -> Type[Tensor]: + return Tensor + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + @dsl_user_op + def __getitem__( + self, crd: Coord, *, loc=None, ip=None + ) -> Union[Tensor, Numeric, IntTuple]: + """Access or slice tensor elements using coordinates. + + This method implements + * tensor evaluation T(c) = *(E + L(c)) when `c` is a coordinate without slicing, or + * tensor slicing operations T(c) = make_tensor(E + L(c), slice(L, c)) + where E is the iterator/engine and L is the layout + + :param crd: Coordinate or slice specification for accessing tensor elements + :type crd: Coord + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Tensor element value or sliced subtensor + :rtype: Union[Tensor, ir.Value, IntTuple] + + :raises ValueError: If coordinate access is invalid for the tensor layout + + **Examples:** + + .. code-block:: python + + # Create a tensor with pointer iterator + ptr = make_ptr(cutlass.Float32, 0, cutlass.AddressSpace.gmem) + layout = make_layout((64, 128)) # leftmost mode is major + tensor = make_tensor(ptr, layout) # Tensor using pointer iterator + + # Direct element access loads from memory + val = tensor[0] # Loads element at offset 0 + val = tensor[1] # Loads element at offset 4 (4bytes per Float32) + val = tensor[(0, 1)] # Loads element at offset 64 + + # Create a coord tensor + layout = make_layout((64, 128), stride=(1 * E(0), 1 * E(1))) + tensor = make_tensor((128, 128), layout) + + # Direct element access + val = tensor[0] # Returns (128, 128) + val = tensor[(0, 1)] # Returns (128, 129) + + # Slice access + sliced = view[(3, None)] # Returns tensor slice + + .. note:: + Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar + dereference operations. Attempting to set individual elements of tensors with + these element types will result in errors. + + **Examples:** + + .. code-block:: python + + # Unsupported operations with sub-byte types: + ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + # The following will raise an error: + val = tensor[0] # Error: sub-byte scalar dereference not supported + + # Similarly for other sub-byte types: + ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + val = tensor[0] # Error: sub-byte scalar dereference not supported + """ + if has_underscore(crd): + return slice_(self.value, crd, loc=loc, ip=ip) + elif isinstance(self.type, _cute_ir.CoordTensorType): + res = _cute_ir.get_iter( + slice_(self, crd, loc=loc, ip=ip).value, loc=loc, ip=ip + ) + return _unpack_x_tuple(res, loc=loc, ip=ip) + else: + self._check_can_load_store() + self._check_can_dereference() + + crd_val = _pack_coord(crd, loc=loc, ip=ip) + data_val = _cute_ir.memref_load(self.value, crd_val, loc=loc, ip=ip) + return self.element_type(data_val) + + def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None): + orig_dtype = data.dtype + # Implicit upcast to wider type + if ( + data.dtype.is_same_kind(self.element_type) + and self.element_type.width >= data.dtype.width + ): + data = data.to(self.element_type, loc=loc, ip=ip) # type: ignore + + if data.dtype.width != self.element_type.width: + raise ValueError( + f"Type mismatch, store {orig_dtype} (-> {data.dtype}) " + f"to Tensor with element type {self.element_type}" + ) + + if data.dtype is Boolean and self.element_type is Boolean: + # Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory + val = data.ir_value_int8(loc=loc, ip=ip) + else: + val = data.ir_value(loc=loc, ip=ip) + return val + + @dsl_user_op + def __setitem__( + self, + crd: Coord, + data: Union[int, float, ir.Value, Numeric, "TensorSSA"], + *, + loc=None, + ip=None, + ) -> None: + """Set tensor elements at specified coordinates. + + Assigns values to tensor elements through direct coordinate access or slice assignment. + For slice assignment, the value must be a TensorSSA with matching shape. + + :param crd: Coordinate or slice specification for tensor element assignment + :type crd: Coord + :param data: Value to assign - can be scalar or TensorSSA for slice assignment + :type data: Union[int, float, ir.Value, Numeric, TensorSSA] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises ValueError: If tensor type doesn't support load/store operations + :raises ValueError: If slice assignment value is not a TensorSSA + :raises ValueError: If value type doesn't match tensor element type + :raises NotImplementedError: If value type is not supported + + .. note:: + Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar + dereference operations. Attempting to set individual elements of tensors with + these element types will result in errors. + + **Examples:** + + .. code-block:: python + + # Unsupported operations with sub-byte types: + ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + # The following will raise an error: + tensor[0] = 1.0 # Error: sub-byte scalar dereference not supported + + # Similarly for other sub-byte types: + ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + tensor[0] = 0.5 # Error: sub-byte scalar dereference not supported + """ + self._check_can_load_store() + + # convert scalar type + if not has_underscore(crd): + self._check_can_dereference() + # First, convert ir.Value to Numeric + if isinstance(data, ir.Value): + data = as_numeric(data) + elif isinstance(data, (int, float, bool)): + data = as_numeric(data) + + if not isinstance(data, Numeric): + raise ValueError(f"unsupported data type: {type(data)}") + + # Implicit upcast to wider type + val = self._cvt_to_dest(data, loc=loc, ip=ip) + if val.type != self.type.value_type: + raise ValueError( + f"type mismatch, store {val.type} to {self.element_type}" + ) + + crd_val = _pack_coord(crd, loc=loc, ip=ip) + _cute_ir.memref_store(self.value, crd_val, val, loc=loc, ip=ip) + else: + if not isinstance(data, TensorSSA): + raise ValueError(f"Expected TensorSSA, but got {data}") + + self.__getitem__(crd, loc=loc, ip=ip).store(data, loc=loc, ip=ip) # type: ignore + + @property + def __class__(self) -> Type[Tensor]: + return Tensor + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + @property + @lru_cache_ir() + def iterator(self) -> Union[Pointer, IntTuple]: + return self._iterator + + @property + @dsl_user_op + @lru_cache_ir() + def layout(self, *, loc=None, ip=None) -> Layout: + return _cute_ir.get_layout(self.value, loc=loc, ip=ip) + + @property + @lru_cache_ir() + def shape(self) -> Shape: + return self.layout.shape + + @property + @lru_cache_ir() + def stride(self) -> Stride: + if isinstance(self.layout.type, _cute_ir.ComposedLayoutType): + raise ValueError("can't get stride from composed layout") + return self.layout.stride + + @property + def leading_dim(self) -> Union[int, Tuple[int], None]: + """Get the leading dimension of this Tensor. + + :return: The index or indices of the first mode (from left to right) with stride 1 + :rtype: Union[int, Tuple[int], None] + :returns: + - int: Single leading dimension index if found + - Tuple[int]: Tuple of indices for nested leading dimensions + - None: If no leading dimension is found + + :postcondition: ``get(self.stride(), mode=self.leading_dim()) == 1 if self.leading_dim() != None else True`` + """ + return leading_dim(self.shape, self.stride) + + @property + @lru_cache_ir() + def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: + return self._dtype + + @property + @lru_cache_ir() + def memspace(self) -> AddressSpace: + if isinstance(self.iterator, Pointer): + return self.iterator.memspace + + raise ValueError(f"{self} doesn't have memspace") + + @dsl_user_op + def load(self, *, loc=None, ip=None) -> "TensorSSA": + """Load tensor elements as a vector. + + Loads all elements of the tensor into a vector representation, assuming the tensor + has a static shape and is in a memory space that supports load operations. + + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Vector representation of tensor elements + :rtype: TensorSSA + + :raises ValueError: If tensor has dynamic layout + :raises ValueError: If tensor memory space doesn't support load operations + """ + from .core import is_static + + if not is_static(self.shape): + raise ValueError("dynamic layout doesn't support load") + + self._check_can_load_store() + + res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip) + if self.element_type is Boolean: + assert res_vect.type.element_type == T.i8(), ( + f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}" + ) + zeros = full_like(self, 0, Int8, loc=loc, ip=ip) + res_vect = arith.cmpi( + arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip + ) + return TensorSSA(res_vect, self.shape, self.element_type) + + @dsl_user_op + def store(self, data: "TensorSSA", *, loc=None, ip=None): + """Store vector data into tensor. + + Stores vector data into the tensor, assuming matching shapes and a memory space + that supports store operations. + + :param data: Vector data to store into tensor + :type data: TensorSSA + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises ValueError: If tensor has dynamic layout + :raises ValueError: If tensor memory space doesn't support store operations + :raises ValueError: If data shape doesn't match tensor shape + """ + if not isinstance(data, TensorSSA): + raise ValueError(f"Expected TensorSSA, but got {type(data)}") + + if not is_static(self.shape): + raise ValueError("Dynamic layout doesn't support vectorized store") + + self._check_can_load_store() + + n_elems = size(self.shape, loc=loc, ip=ip) + if n_elems != size(data.shape, loc=loc, ip=ip): + raise ValueError( + f"lhs and rhs must have the same shape, but got {self.shape} and {data.shape}" + ) + + elem_mlir_type = cutlass_arith.element_type(data.dtype.mlir_type) + if cutlass_arith.is_narrow_precision(elem_mlir_type): + if elem_mlir_type.width * n_elems % 32 != 0: + raise ValueError( + f"narrow precision type must be 32-bit aligned vector, but got {elem_mlir_type} with {n_elems} elements" + ) + + # Implicit upcast to wider type + new_data = self._cvt_to_dest(data, loc=loc, ip=ip) + + return _cute_ir.memref_store_vec( + new_data, self.value, row_major=True, loc=loc, ip=ip + ) + + @dsl_user_op + def fill(self, value: Numeric, *, loc=None, ip=None) -> None: + """Fill tensor with a constant value. + + Fills all elements of the tensor with the specified value, assuming static size + and supported memory space. + + :param value: Value to fill tensor with + :type value: Union[int, float] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises NotImplementedError: If tensor has dynamic size + + **Examples:** + + .. code-block:: python + + # Create tensor from numpy array + b = np.random.randn(4, 8).astype(np.float32) + tensor = from_dlpack(b) + + # Fill tensor with constant value + tensor.fill(0.5) # All elements become 0.5 + """ + self._check_can_load_store() + + sz = size(self, loc=loc, ip=ip) + if type(sz) is not int: + raise NotImplementedError(f"dynamic size is not supported: {self.type}") + + # Should we cast to destination type even with narrow cast? + dst_type = self.element_type + scalar_val = dst_type(value, loc=loc, ip=ip) + vect_val = full( + self.shape, fill_value=scalar_val, dtype=dst_type, loc=loc, ip=ip + ) + self.store(vect_val, loc=loc, ip=ip) + + def _check_can_load_store(self): + if not isinstance(self.type, _cute_ir.MemRefType) or self.memspace not in ( + AddressSpace.rmem, + AddressSpace.smem, + AddressSpace.gmem, + AddressSpace.generic, + ): + raise ValueError(f"{self} doesn't support load and store") + + def _check_can_dereference(self): + # Check for sub-byte types and raise error if needed + if self.element_type.width % 8 != 0 and self.element_type is not Boolean: + raise ValueError( + f"Sub-byte scalar dereference not supported for type {self.element_type}" + ) + + +# +# Tensor API +# + + +@dsl_user_op +def make_tensor( + iterator, layout: Union[Shape, Layout, ComposedLayout], *, loc=None, ip=None +) -> Tensor: + """Creates a tensor by composing an engine (iterator/pointer) with a layout. + + A tensor is defined as T = E ∘ L, where E is an engine (array, pointer, or counting iterator) + and L is a layout that maps logical coordinates to physical offsets. The tensor + evaluates coordinates by applying the layout mapping and dereferencing the engine + at the resulting offset. + + :param iterator: Engine component (pointer, iterator, or counting iterator) that provides + data access capabilities + :type iterator: Union[Pointer, IntTuple] + :param layout: Layout component that defines the mapping from logical coordinates to + physical offsets + :type layout: Union[Shape, Layout, ComposedLayout] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A tensor object representing the composition E ∘ L + :rtype: Tensor + + :raises ValueError: If iterator type is not supported + + **Examples:** + + .. code-block:: python + + # Create a tensor with row-major layout + layout = make_layout((64, 128), stride=(128, 1)) + tensor = make_tensor(ptr, layout) + + # Create a tensor with hierarchical layout + layout = make_layout(((128, 8), (1, 4, 1)), stride=((32, 1), (0, 8, 4096))) + tensor = make_tensor(smem_ptr, layout) + + # Create a coord tensor + layout = make_layout(2, stride=16 * E(0)) + tensor = make_tensor(5, layout) + + Notes: + - The engine (iterator) must support random access operations + - Common engine types include raw pointers, arrays, and random-access iterators + - The layout defines both the shape (logical dimensions) and stride (physical mapping) + - Supports both direct coordinate evaluation T(c) and partial evaluation (slicing) + """ + if not isinstance(layout, (Layout, ComposedLayout)): + layout = make_layout(layout, loc=loc, ip=ip) + # Automatic decay to normal layout + elif isinstance(layout, ComposedLayout) and layout.is_normal: + layout = layout.outer + + res_ty = None + if is_integer(iterator) or isinstance(iterator, tuple): + iterator = _pack_int_tuple(iterator, loc=loc, ip=ip) + res_ty = _cute_ir.CoordTensorType.get(iterator.type, layout.type) + elif isinstance(iterator, Pointer): + iterator = iterator.value + res_ty = _cute_ir.MemRefType.get(iterator.type, layout.type) + elif isinstance(iterator, ir.Value) and isinstance( + iterator.type, _cute_nvgpu_ir.SmemDescType + ): + res_ty = _cute_nvgpu_ir.SmemDescViewType.get(layout.type) + else: + raise TypeError(f"unsupported iterator type, got {type(iterator)}") + + if isinstance(layout, _ComposedLayout): + layout = layout.value + return _cute_ir.make_view( + result=res_ty, iter=iterator, layout=layout, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor: + """Creates an identity tensor with the given shape. + + An identity tensor maps each coordinate to itself, effectively creating a counting + sequence within the shape's bounds. This is useful for generating coordinate indices + or creating reference tensors for layout transformations. + + :param shape: The shape defining the tensor's dimensions. Can be a simple integer + sequence or a hierarchical structure ((m,n),(p,q)) + :type shape: Shape + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A tensor that maps each coordinate to itself + :rtype: Tensor + + **Examples:** + + .. code-block:: python + + # Create a simple 1D coord tensor + tensor = make_identity_tensor(6) # [0,1,2,3,4,5] + + # Create a 2D coord tensor + tensor = make_identity_tensor((3,2)) # [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)] + + # Create hierarchical coord tensor + tensor = make_identity_tensor(((2,1),3)) + # [((0,0),0),((1,0),0),((0,0),1),((1,0),1),((0,0),2),((1,0),2)] + + Notes: + - The shape parameter follows CuTe's IntTuple concept + - Coordinates are ordered colexicographically + - Useful for generating reference coordinates in layout transformations + """ + shape_val = _pack_shape(shape, loc=loc, ip=ip) + return _cute_ir.make_identity_tensor(shape_val, loc=loc, ip=ip) + + +@dsl_user_op +def make_rmem_tensor( + layout_or_shape: Union[Layout, Shape], dtype: Type[Numeric], *, loc=None, ip=None +) -> Tensor: + """Creates a tensor in register memory with the specified layout/shape and data type. + + This function allocates a tensor in register memory (rmem) usually on stack with + either a provided layout or creates a new layout from the given shape. The tensor + will have elements of the specified numeric data type. + + :param layout_or_shape: Either a Layout object defining the tensor's memory organization, + or a Shape defining its dimensions + :type layout_or_shape: Union[Layout, Shape] + :param dtype: The data type for tensor elements (must be a Numeric type) + :type dtype: Type[Numeric] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A tensor allocated in register memory + :rtype: Tensor + + **Examples:** + + .. code-block:: python + + # Create rmem tensor with explicit layout + layout = make_layout((128, 32)) + tensor = make_rmem_tensor(layout, cutlass.Float16) + + # Create rmem tensor directly from shape + tensor = make_rmem_tensor((64, 64), cutlass.Float32) + + Notes: + - Uses 32-byte alignment to support .128 load/store operations + - Boolean types are stored as 8-bit integers + - Handles both direct shapes and Layout objects + """ + if not issubclass(dtype, Numeric): + raise TypeError(f"value_type must be a type of Numeric, but got {type(dtype)}") + elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8() + + # Alignment for register memory is useless(?), pick-up large enough number + # to allow .128 (> 16B) load store + alignment = 32 + layout = None + if not isinstance(layout_or_shape, Layout): + layout = make_layout(layout_or_shape, loc=loc, ip=ip) + elif isinstance(layout_or_shape, _ComposedLayout): + layout = layout_or_shape.value + else: + layout = layout_or_shape + + ptr_ty = _cute_ir.PtrType.get(elem_ty, AddressSpace.rmem, alignment) + res_ty = _cute_ir.MemRefType.get(ptr_ty, layout.type) + tensor = _cute_ir.memref_alloca(res_ty, layout=layout, loc=loc, ip=ip) + return _Tensor(tensor.value, dtype) + + +@dsl_user_op +@deprecated("`make_fragment` is deprecated, use `make_rmem_tensor` instead") +def make_fragment( + layout_or_shape: Union[Layout, Shape], dtype: Type[Numeric], *, loc=None, ip=None +) -> Tensor: + return make_rmem_tensor(layout_or_shape, dtype, loc=loc, ip=ip) + + +@dsl_user_op +def make_rmem_tensor_like( + src: Union[Layout, ComposedLayout, Tensor], + dtype: Optional[Type[Numeric]] = None, + *, + loc=None, + ip=None, +) -> Tensor: + """Creates a tensor in register memory with the same shape as the input layout but + compact col-major strides. This is equivalent to calling `make_rmem_tensor(make_layout_like(tensor))`. + + This function allocates a tensor in register memory (rmem) usually on stack with + with the compact layout like the source. The tensor will have elements of the + specified numeric data type or the same as the source. + + :param src: The source layout or tensor whose shape will be matched + :type src: Union[Layout, ComposedLayout, Tensor] + :param dtype: The element type for the fragment tensor, defaults to None + :type dtype: Type[Numeric], optional + :param loc: Source location for MLIR operations, defaults to None + :type loc: Location, optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: InsertionPoint, optional + + :return: A new layout or fragment tensor with matching shape + :rtype: Union[Layout, Tensor] + + **Examples:** + + Creating a rmem tensor from a tensor: + + .. code-block:: python + + smem_tensor = cute.make_tensor(smem_ptr, layout) + rmem_tensor = cute.make_rmem_tensor_like(smem_tensor, cutlass.Float32) + # frag_tensor will be a register-backed tensor with the same shape + + Creating a fragment with a different element type: + + .. code-block:: python + + tensor = cute.make_tensor(gmem_ptr, layout) + rmem_bool_tensor = cute.make_rmem_tensor_like(tensor, cutlass.Boolean) + # bool_frag will be a register-backed tensor with Boolean elements + + **Notes** + + - When used with a Tensor, if a type is provided, it will create a new + fragment tensor with that element type. + - For layouts with ScaledBasis strides, the function creates a fragment + from the shape only. + - This function is commonly used in GEMM and other tensor operations to + create register storage for intermediate results. + + """ + if not isinstance(src, (Layout, ComposedLayout, Tensor)): + raise TypeError( + f"src must be a Layout or ComposedLayout or Tensor, got {type(src)}" + ) + + if isinstance(src, Tensor): + if isinstance(src.type, _cute_ir.CoordTensorType): + if dtype is None: + raise ValueError( + "dtype must be provided when src is a coordinate tensor" + ) + + res_dtype = dtype + compact_layout = make_layout(src.shape, loc=loc, ip=ip) + src_layout = _cute_ir.make_layout_like(compact_layout, loc=loc, ip=ip) + else: + res_dtype = dtype or src.element_type + src_layout = src.layout + else: + if dtype is None: + raise ValueError("dtype must be provided when src is a layout") + + res_dtype = dtype + src_layout = src + + if isinstance(src_layout, _ComposedLayout): + src_layout = src_layout.value + + res_layout = _cute_ir.make_layout_like(src_layout, loc=loc, ip=ip) + return make_rmem_tensor(res_layout, res_dtype, loc=loc, ip=ip) + + +@overload +def make_fragment_like( + src: Tensor, dtype: Optional[Type[Numeric]], *, loc=None, ip=None +) -> Tensor: ... +@overload +def make_fragment_like(src: Layout, *, loc=None, ip=None) -> Layout: ... +@overload +def make_fragment_like(src: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +def make_fragment_like(src, dtype=None, *, loc=None, ip=None): + # Keep code to avoid potential regression + if isinstance(src, (Layout, _ComposedLayout)): + if isinstance(src, _ComposedLayout): + src = src.value + new_layout = _cute_ir.make_fragment_like(src, loc=loc, ip=ip) + + if dtype is not None: + # call make_rmem_tensor to convert layout to tensor + return make_rmem_tensor(new_layout, dtype, loc=loc, ip=ip) + else: + return new_layout + else: + return make_rmem_tensor_like(src, dtype, loc=loc, ip=ip) + + +@dsl_user_op +def recast_tensor( + src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None +): + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") + + if dtype is Boolean: + dst_width = 8 + else: + dst_width = dtype.width + + if src.element_type is Boolean: + src_width = 8 + else: + src_width = src.element_type.width + + src_iter = recast_ptr(src.iterator, dtype=dtype, loc=loc, ip=ip) + src_layout = recast_layout(dst_width, src_width, src.layout, loc=loc, ip=ip) + return make_tensor(src_iter, src_layout, loc=loc, ip=ip) + + +@dsl_user_op +def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: + offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip) + if isinstance(tensor.iterator, Pointer): + return make_tensor( + tensor.iterator.__add__(offset, loc=loc, ip=ip), + tensor.layout, + loc=loc, + ip=ip, + ) + elif is_integer(tensor.iterator) or isinstance(tensor.iterator, tuple): + new_iter = _cute_ir.add_offset( + _pack_int_tuple(tensor.iterator, loc=loc, ip=ip), + _pack_int_tuple(offset, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + return make_tensor( + _unpack_x_tuple(new_iter, loc=loc, ip=ip), + tensor.layout, + loc=loc, + ip=ip, + ) + else: + raise ValueError(f"unsupported tensor for domain_offset, got {tensor}") + + +@dsl_user_op +def print_tensor( + tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None +): + """Print content of the tensor in human readable format. + + Outputs the tensor data in a structured format showing both metadata + and the actual data values. The output includes tensor type information, + layout details, and a formatted array representation of the values. + + :param tensor: The tensor to print + :type tensor: Tensor + :param verbose: If True, includes additional debug information in the output + :type verbose: bool + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer for IR generation, defaults to None + :type ip: insertion pointer, optional + :raises NotImplementedError: If the tensor type doesn't support trivial dereferencing + + **Example output:** + + .. code-block:: text + + tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= + [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], + [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], + [ 0.3426, 0.5856, 0.1541, 0.2923, 0.6976], + [-0.1649, 0.8811, 0.1788, 0.1404, 0.2568], + [-0.2944, 0.8593, 0.4171, 0.8998, 0.1766], + [ 0.8814, 0.7919, 0.7390, 0.4566, 0.1576], + [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], + [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) + """ + if isinstance(tensor, TensorSSA): + tmp = make_rmem_tensor(tensor.shape, tensor.dtype) + tmp.store(tensor) + tensor = tmp + + if isinstance(tensor.type, _cute_ir.MemRefType): + if tensor.element_type.is_integer: + signed = tensor.element_type.signed + else: + signed = False + else: + signed = True + + _cute_ir.print_view(tensor.value, verbose=verbose, is_signed=signed, loc=loc, ip=ip) + + +def _infer_broadcast_shape(*shapes: Shape) -> Shape: + """ + Infer the broadcasted shape from multiple input shapes according to broadcasting rules. + + :param shapes: Variable number of tensor shapes to broadcast together. + :type shapes: Shape + :return: The broadcasted shape. + :rtype: Shape + :raises ValueError: If no shapes provided or shapes cannot be broadcast together. + """ + + if len(shapes) == 0: + raise ValueError("At least one shape must be provided") + elif len(shapes) == 1: + return shapes[0] + + def _broadcast(*values): + non_one_values = [v for v in values if v != 1] + if len(non_one_values) == 0: + return 1 + elif len(set(non_one_values)) == 1: + return non_one_values[0] + else: + raise ValueError(f"cannot broadcast {values}") + + max_rank = max(rank(shape) for shape in shapes) + ext_shapes = tuple(append(shape, 1, up_to_rank=max_rank) for shape in shapes) + res_shape = transform_leaf(_broadcast, *ext_shapes) + return res_shape + + +class TensorSSA(cutlass_arith.ArithValue): + """A class representing thread local data from CuTe Tensor in value semantic and immutable. + + :param value: Flatten vector as ir.Value holding logic data of SSA Tensor + :type value: ir.Value + :param shape: The nested shape in CuTe of the vector + :type shape: Shape + :param dtype: Data type of the tensor elements + :type dtype: Type[Numeric] + + :ivar _shape: The nested shape in CuTe of the vector + :ivar _dtype: Data type of the tensor elements + + :raises ValueError: If shape is not static + """ + + def __init__(self, value, shape: Shape, dtype: Type[Numeric]): + """Initialize a new TensorSSA object. + + :param value: Flatten vector as ir.Value holding logic data of SSA Tensor + :type value: ir.Value + :param shape: The nested shape in CuTe of the vector + :type shape: Shape + :param dtype: Data type of the tensor elements + :type dtype: Type[Numeric] + :raises ValueError: If shape is not static + """ + if not is_static(shape): + raise ValueError("dynamic shape is not supported") + + signed = dtype.signed if issubclass(dtype, Integer) else False + super().__init__(value, signed) + + self._shape = shape + self._dtype = dtype + self._layout = None + + @property + def dtype(self) -> Type[Numeric]: + return self._dtype + + @property + def element_type(self) -> Type[Numeric]: + return self._dtype + + def __extract_mlir_values__(self): + return [self] + + def __new_from_mlir_values__(self, values): + return TensorSSA(values[0], self.shape, self.dtype) + + def __str__(self): + return f"tensor_value<{self.type} o {self.shape}>" + + @property + def shape(self): + return self._shape + + @overload + def _apply_op( + self, op, other: "TensorSSA", flip=False, *, loc, ip + ) -> "TensorSSA": ... + @overload + def _apply_op( + self, op, other: cutlass_arith.ArithValue, flip=False, *, loc, ip + ) -> "TensorSSA": ... + @overload + def _apply_op( + self, op, other: Union[int, float, bool], flip=False, *, loc, ip + ) -> "TensorSSA": ... + + def _apply_op(self, op, other, flip=False, *, loc=None, ip=None): + # Canonicalize into Numeric + if isinstance(other, (int, float, bool)) or ( + not isinstance(other, TensorSSA) + and isinstance(other, cutlass_arith.ArithValue) + ): + other = as_numeric(other) + + # Promote types + lhs, rhs, res_type = _binary_op_type_promote(self, other) + + # Promote scalar to vector + if not isinstance(rhs, TensorSSA): + assert isinstance(rhs, Numeric), ( + f"Expected rhs to be Numeric, but got {rhs}" + ) + vect_val = vector.broadcast(lhs.type, rhs.ir_value(loc=loc, ip=ip)) + + rhs = TensorSSA(vect_val, lhs.shape, lhs.dtype) + + if flip: + lhs, rhs = rhs, lhs + + if op in ( + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne, + ): + res_type = Boolean + + assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}" + + # broadcast to the same shape + res_shape = _infer_broadcast_shape(lhs.shape, rhs.shape) + lhs = lhs.broadcast_to(res_shape) + rhs = rhs.broadcast_to(res_shape) + + if ( + op in (operator.add, operator.sub) + and lhs.dtype == Boolean + and rhs.dtype == Boolean + ): + res = op(lhs.to(Int32), rhs.to(Int32)) + zero = zeros_like(res) + res = res.__ne__(zero).to(res_type) + else: + lhs_val = lhs.maybe_downcast() + rhs_val = rhs.maybe_downcast() + + if issubclass(lhs.dtype, Integer): + lhs_val = lhs_val.with_signedness(lhs.dtype.signed) + + if issubclass(rhs.dtype, Integer): + rhs_val = rhs_val.with_signedness(rhs.dtype.signed) + + res_vect = op(lhs_val, rhs_val) + res = TensorSSA(res_vect, lhs._shape, res_type) + + return res + + @dsl_user_op + def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA": + """ + Broadcast the tensor to the target shape. + """ + # pad source shape to the same rank + shape = append(self.shape, 1, up_to_rank=rank(target_shape)) + if shape == target_shape: + return self + + def _check_broadcast(s, t): + if s != t and s != 1: + raise ValueError( + f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}" + ) + + transform_leaf(_check_broadcast, shape, target_shape) + + # reshape to flatten N-D vector + flat_shp = flatten_to_tuple(shape) + temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type) + temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) + + # broadcast to result N-D vector + flat_tgt_shp = flatten_to_tuple(target_shape) + temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type) + temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip) + + res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore + res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip) + + return TensorSSA(res_1d_vect, target_shape, self.dtype) + + @dsl_user_op + def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the results of tensor^other. + + :param other: The other tensor for exponent. + :type other: TensorSSA + :return: The power of the tensor. + :rtype: TensorSSA + """ + return self._apply_op(operator.pow, other, loc=loc, ip=ip) + + @dsl_user_op + def __rpow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the results of other^tensor. + + :param other: The other tensor to compute power with. + :type other: TensorSSA + :return: The element-wise power of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.pow, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __add__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the sum of the tensor and another tensor. + + :param other: The other tensor to add. + :type other: TensorSSA + :return: The sum of the two tensors with the same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.add, other, loc=loc, ip=ip) + + @dsl_user_op + def __radd__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the sum of the tensor and another tensor (reverse add) + + :param other: The other tensor to add. + :type other: TensorSSA + :return: The sum of the two tensors with the same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.add, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __sub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the difference of the tensor and another tensor. + + :param other: The other tensor to subtract. + :type other: TensorSSA + :return: The subtraction of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.sub, other, loc=loc, ip=ip) + + @dsl_user_op + def __rsub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the difference of the tensor and another tensor (reverse subtract) + + :param other: The other tensor to subtract. + :type other: TensorSSA + :return: The subtraction of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.sub, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __mul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the multiplication of the tensor and another tensor. + + :param other: The other tensor to multiply. + :type other: TensorSSA + :return: The multiplication of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mul, other, loc=loc, ip=ip) + + @dsl_user_op + def __rmul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the multiplication of the tensor and another tensor (reverse multiply) + + :param other: The other tensor to multiply. + :type other: TensorSSA + :return: The multiplication of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mul, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __mod__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the modulo of the tensor and another tensor. + + :param other: The other tensor to compute modulo with. + :type other: TensorSSA + :return: The element-wise modulo of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mod, other, loc=loc, ip=ip) + + @dsl_user_op + def __rmod__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the modulo of the tensor and another tensor (reverse modulo) + + :param other: The other tensor to compute modulo with. + :type other: TensorSSA + :return: The element-wise modulo of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mod, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __floordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the floordiv(//) of the tensor and another tensor. + + :param other: The other tensor to compute floordiv with. + :type other: TensorSSA + :return: The floordiv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.floordiv, other, loc=loc, ip=ip) + + @dsl_user_op + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the floordiv(//) of the tensor and another tensor (reverse floordiv) + + :param other: The other tensor to compute floordiv with. + :type other: TensorSSA + :return: The floordiv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.floordiv, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __truediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the truediv(/) of the tensor and another tensor. + + :param other: The other tensor to compute truediv with. + :type other: TensorSSA + :return: The truediv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.truediv, other, loc=loc, ip=ip) + + @dsl_user_op + def __rtruediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the truediv(/) of the tensor and another tensor (reverse truediv) + + :param other: The other tensor to compute truediv with. + :type other: TensorSSA + :return: The truediv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.truediv, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __eq__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the comparison of the tensor and another tensor as mask + + :param other: The other tensor to compare. + :type other: TensorSSA + :return: The comparison of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.eq, other, loc=loc, ip=ip) + + @dsl_user_op + def __ne__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise not equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self != other. + :rtype: TensorSSA + """ + return self._apply_op(operator.ne, other, loc=loc, ip=ip) + + @dsl_user_op + def __lt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise less than comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self < other. + :rtype: TensorSSA + """ + return self._apply_op(operator.lt, other, loc=loc, ip=ip) + + @dsl_user_op + def __le__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise less than or equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self <= other. + :rtype: TensorSSA + """ + return self._apply_op(operator.le, other, loc=loc, ip=ip) + + @dsl_user_op + def __gt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise greater than comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self > other. + :rtype: TensorSSA + """ + return self._apply_op(operator.gt, other, loc=loc, ip=ip) + + @dsl_user_op + def __ge__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise greater than or equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self >= other. + :rtype: TensorSSA + """ + return self._apply_op(operator.ge, other, loc=loc, ip=ip) + + @dsl_user_op + def __xor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise XOR of the tensor and another tensor. + + :param other: The other tensor to perform XOR with. + :type other: TensorSSA + :return: The element-wise XOR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.xor, other, loc=loc, ip=ip) + + @dsl_user_op + def __rxor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the bitwise XOR of the tensor and another tensor. + + :param other: The other tensor to compute XOR with. + :type other: TensorSSA + :return: The element-wise bitwise XOR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.xor, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __or__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise OR of the tensor and another tensor. + + :param other: The other tensor to perform OR with. + :type other: TensorSSA + :return: The element-wise OR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.or_, other, loc=loc, ip=ip) + + @dsl_user_op + def __ror__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise OR of the tensor and another tensor. + + :param other: The other tensor to perform OR with. + :type other: TensorSSA + :return: The element-wise OR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.or_, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __and__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise AND of the tensor and another tensor. + + :param other: The other tensor to perform AND with. + :type other: TensorSSA + :return: The element-wise AND of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.and_, other, loc=loc, ip=ip) + + @dsl_user_op + def __rand__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise AND of the tensor and another tensor. + + :param other: The other tensor to perform AND with. + :type other: TensorSSA + :return: The element-wise AND of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip) + + @dsl_user_op + def __neg__(self, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the negation of the tensor. + + :return: The element-wise negation of the tensor + :rtype: TensorSSA + """ + + return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip) + + def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None): + # Coalesce and flatten source layout at terminal of coordinate + # (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...) + crd_shp = product_like(self._shape, target_profile=crd, loc=loc, ip=ip) + + # Flatten coordinate + flat_shp = flatten(crd_shp) + assert isinstance(flat_shp, tuple) and is_static(flat_shp) + # (C_0,(C_1,...), ...) -> (C_0,C_1,C_2,...) + flat_crd = flatten(crd) + + assert isinstance(flat_crd, tuple) and is_static(flat_crd) + return flat_shp, flat_crd + + def _build_result(self, res_vect, res_shp, *, loc=None, ip=None): + if isinstance(res_shp, ir.Value): + raise ValueError(f"Expected static shape, but got {self._shape}") + + # cast back to 1D vector + res_1d_ty = ir.VectorType.get([size(res_shp)], self.type.element_type) + res_1d_vect = vector.shape_cast(res_1d_ty, res_vect, loc=loc, ip=ip) + return TensorSSA(res_1d_vect, res_shp, self.dtype) + + @dsl_user_op + def __getitem__( + self, crd: Coord, *, loc=None, ip=None + ) -> Union["TensorSSA", Numeric]: + """Access or slice tensor elements using coordinates. + + This method implements tensor evaluation T(c) = *(E + L(c)) where E is the iterator/engine + and L is the layout. It supports both direct element access and slicing operations. + + :param crd: Coordinate or slice specification for accessing tensor elements + :type crd: Coord + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Tensor element value or sliced subtensor + :rtype: Union[TensorSSA, Numeric] + + :raises ValueError: If coordinate access is invalid for the tensor layout + + **Examples:** + + .. code-block:: python + + # Create a fragment from rmem as shape (8, 4) + layout = make_layout((8, 4)) + tensor = make_rmem_tensor(layout, Float32) + frg = tensor.load() + + # Direct element access + val = frg[0] # Returns first element of fragment + val = frg[(0, 1)] # Returns element at (0, 1) + + # Slice access + sliced = frg[(3, None)] # Returns fragment slice + """ + # short-cut to no-op + if crd is None: + return self + + if not has_underscore(crd): + if self._layout is None: + self._layout = make_layout(self._shape, loc=loc, ip=ip) + idx = crd2idx(crd, self._layout, loc=loc, ip=ip) + assert not isinstance(idx, tuple), "index must be scalar" + idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip) + res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip) + return self.dtype(res_val) + + if not is_static(crd): + raise ValueError("dynamic coordinate is not supported") + + flat_shp, flat_crd = self._flatten_shape_and_coord(crd, loc=loc, ip=ip) + + multi_dim_ty = ir.VectorType.get(list(flat_shp), self.type.element_type) + # vector -> vector + tmp_vect = vector.shape_cast(multi_dim_ty, self, loc=loc, ip=ip) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self._shape, crd, loc=loc, ip=ip) + assert not isinstance(res_shp, ir.Value), ( + f"Expected static shape and coordinates, but got {self._shape} and {crd}" + ) + + # Offsets is index of coordinates if NOT `_` otherwise 0 + offsets = [c if c is not None else 0 for c in flat_crd] + # Sizes is size of shapes if `_` otherwise 1 + sizes = [s if c is None else 1 for s, c in zip(flat_shp, flat_crd)] + # Logic stride to index vector. Only support stride-1 by vector + strides = [1] * rank(flat_shp) + + # Vector slice on N-D vector + res_ty = ir.VectorType.get(list(sizes), self.type.element_type) + res_vect = vector.extract_strided_slice( + res_ty, + tmp_vect, + offsets=offsets, + sizes=sizes, + strides=strides, + loc=loc, + ip=ip, + ) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self._shape, crd, loc=loc, ip=ip) + return self._build_result(res_vect, res_shp, loc=loc, ip=ip) + + @dsl_user_op + def to(self, dtype: Type[Numeric], *, loc=None, ip=None): + """Convert the tensor to a different numeric type. + + :param dtype: The target numeric type to cast to. + :type dtype: Type[Numeric] + :return: A new tensor with the same shape but with elements cast to the target type. + :rtype: TensorSSA + :raises TypeError: If dtype is not a subclass of Numeric. + :raises NotImplementedError: If dtype is an unsigned integer type. + """ + if dtype is ir.Value: + return self + + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {type(dtype)}") + + src_dtype = self.dtype + if src_dtype == dtype: + return self + + # maybe downcast can lose signedness + src = self.maybe_downcast().with_signedness(self.signed) + if src_dtype.is_float and dtype.is_float: + if src_dtype == Float4E2M1FN and dtype in (Float16, Float32): + res_vect = cvt_f4e2m1_f16_intrinsic( + src, size(self.shape), loc=loc, ip=ip + ) + if dtype == Float32: + res_vect = cutlass_arith.cvtf( + res_vect, dtype.mlir_type, loc=loc, ip=ip + ) + else: + res_vect = cutlass_arith.cvtf(src, dtype.mlir_type, loc=loc, ip=ip) + elif src_dtype.is_float and issubclass(dtype, Integer): + res_vect = cutlass_arith.fptoi( + src, dtype.signed, dtype.mlir_type, loc=loc, ip=ip + ) + elif issubclass(src_dtype, Integer) and dtype.is_float: + # fast conversion path for supported combinations + if src_dtype in (Int8, Uint8) and dtype == BFloat16: + res_vect = cvt_i8_bf16_intrinsic(src, size(self.shape), loc=loc, ip=ip) + elif src_dtype == Int4 and dtype == BFloat16: + res_vect = cvt_i4_bf16_intrinsic(src, size(self.shape), loc=loc, ip=ip) + else: + res_vect = cutlass_arith.itofp( + src, src_dtype.signed, dtype.mlir_type, loc=loc, ip=ip + ) + else: + res_vect = cutlass_arith.int_to_int(src, dtype, loc=loc, ip=ip) + + return TensorSSA(res_vect, self._shape, dtype) + + @dsl_user_op + def ir_value(self, *, loc=None, ip=None): + return self + + @dsl_user_op + def ir_value_int8(self, *, loc=None, ip=None): + """ + Returns int8 ir value of Boolean tensor. + When we need to store Boolean tensor ssa, use ir_value_int8(). + + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :return: The int8 value of this Boolean + :rtype: ir.Value + """ + assert self.element_type is Boolean, ( + f"Only boolean type needs to be converted to int8, got {self.element_type}" + ) + + if not hasattr(self, "_value_int8"): + self._value_int8 = arith.extsi( + T.vector(self.type.shape[0], T.i8()), self, loc=loc, ip=ip + ) + return self._value_int8 + + @dsl_user_op + def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None): + """ + Perform reduce on selected modes with given predefined reduction op. + + :param op: The reduction operator to use (operator.add or operator.mul) + :type op: operator + :param init_val: The initial value for the reduction + :type init_val: numeric + :param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept. + :type reduction_profile: Coord + + :return: The reduced tensor + :rtype: TensorSSA + + **Examples:** + + .. code-block:: python + + reduce(f32 o (4,)) + => f32 + + reduce(f32 o (4, 5)) + => f32 + reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1)) + => f32 o (4,) + reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1))) + => f32 o (4, (5,)) + """ + # short-cut to no-op + if reduction_profile is None: + return self + + if not is_weakly_congruent(reduction_profile, self.shape): + raise ValueError( + f"Expected reduction_profile be weakly congruent to the shape of the tensor, " + f"but got {reduction_profile} and {self.shape}" + ) + + if op is ReductionOp.ADD: + red_kind = vector.CombiningKind.ADD + elif op is ReductionOp.MUL: + red_kind = vector.CombiningKind.MUL + elif op is ReductionOp.MAX: + red_kind = vector.CombiningKind.MAXIMUMF + elif op is ReductionOp.MIN: + red_kind = vector.CombiningKind.MINIMUMF + else: + raise NotImplementedError( + f"{op} is not supported, expected one of " + f"{ReductionOp.ADD, ReductionOp.MUL, ReductionOp.MAX, ReductionOp.MIN}" + ) + + elem_type = self.element_type + # Canonicalize to `Numeric` and convert into MLIR value + init_val = ( + as_numeric(init_val).to(elem_type, loc=loc, ip=ip).ir_value(loc=loc, ip=ip) + ) + + if depth(reduction_profile) == 0: + return vector.reduction( + elem_type.mlir_type, red_kind, self, acc=init_val, loc=loc, ip=ip + ) + + flat_shp, flat_prof = self._flatten_shape_and_coord( + reduction_profile, loc=loc, ip=ip + ) + assert isinstance(flat_prof, tuple), ( + f"Expected flat_prof to be a tuple, got {type(flat_prof)}" + ) + assert depth(flat_shp) == 1 and depth(flat_prof) == 1 + assert rank(flat_shp) == rank(flat_prof) + + temp_ty = ir.VectorType.get(list(flat_shp), elem_type.mlir_type) + temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) + + red_dims = [i for i, x in enumerate(flat_prof) if x is not None] + + temp_acc_shp = slice_(flat_shp, flat_prof, loc=loc, ip=ip) + temp_acc_ty = ir.VectorType.get(list(temp_acc_shp), elem_type.mlir_type) + + init_val = vector.broadcast(temp_acc_ty, init_val, loc=loc, ip=ip) + res_vect = vector.multi_reduction( + red_kind, temp_vect, acc=init_val, reduction_dims=red_dims, loc=loc, ip=ip + ) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self.shape, reduction_profile, loc=loc, ip=ip) + return self._build_result(res_vect, res_shp, loc=loc, ip=ip) + + +@dsl_user_op +def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA: + """ + Return a new TensorSSA of given shape and type, filled with fill_value. + + :param shape: Shape of the new tensor. + :type shape: tuple + :param fill_value: Value to fill the tensor with. + :type fill_value: scalar + :param dtype: Data type of the tensor. + :type dtype: Type[Numeric] + :return: Tensor of fill_value with the specified shape and dtype. + :rtype: TensorSSA + """ + size = product(shape, loc=loc, ip=ip) + if not is_static(size): + raise ValueError("shape must be static") + + if isinstance(fill_value, (ir.Value, int, float, bool)): + fill_value = dtype(fill_value) + elif isinstance(fill_value, Numeric): + fill_value = fill_value.to(dtype, loc=loc, ip=ip) + else: + raise ValueError(f"Expected fill_value be numeric type, but got {fill_value}") + + res_ty = T.vector(size, dtype.mlir_type) + res_val = vector.splat(res_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + return TensorSSA(res_val, shape, dtype) + + +@dsl_user_op +def full_like( + a: Union[TensorSSA, Tensor], + fill_value, + dtype: Union[None, Type[Numeric]] = None, + *, + loc=None, + ip=None, +) -> TensorSSA: + """ + Return a full TensorSSA with the same shape and type as a given array. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: array_like + :param fill_value: Fill value. + :type fill_value: array_like + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Union[None, Type[Numeric]], optional + :return: Tensor of `fill_value` with the same shape and type as `a`. + :rtype: TensorSSA + + .. seealso:: + :func:`empty_like`: Return an empty array with shape and type of input. + :func:`ones_like`: Return an array of ones with shape and type of input. + :func:`zeros_like`: Return an array of zeros with shape and type of input. + :func:`full`: Return a new array of given shape filled with value. + + **Examples:** + + .. code-block:: python + + frg = cute.make_rmem_tensor((2, 3), Float32) + a = frg.load() + b = cute.full_like(a, 1.0) + """ + if not hasattr(a, "shape"): + raise TypeError(f"Expected `a` be shaped type, but got {type(a)}") + + res_dtype = dtype if dtype is not None else a.dtype # type: ignore + return full(a.shape, fill_value, res_dtype, loc=loc, ip=ip) + + +@dsl_user_op +def empty_like(a, dtype=None, *, loc=None, ip=None): + """ + Return a new TensorSSA with the same shape and type as a given array, without initializing entries. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: TensorSSA + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Type[Numeric], optional + :return: Uninitialized tensor with the same shape and type (unless overridden) as `a`. + :rtype: TensorSSA + """ + return full_like(a, 0, dtype, loc=loc, ip=ip) + + +@dsl_user_op +def ones_like(a, dtype=None, *, loc=None, ip=None): + """ + Return a TensorSSA of ones with the same shape and type as a given array. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: TensorSSA + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Type[Numeric], optional + :return: Tensor of ones with the same shape and type (unless overridden) as `a`. + :rtype: TensorSSA + """ + return full_like(a, 1, dtype, loc=loc, ip=ip) + + +@dsl_user_op +def zeros_like(a, dtype=None, *, loc=None, ip=None): + """ + Return a TensorSSA of zeros with the same shape and type as a given array. + + :param a: The shape and data-type of `a` define these same attributes of the returned array. + :type a: TensorSSA + :param dtype: Overrides the data type of the result, defaults to None + :type dtype: Type[Numeric], optional + :return: Tensor of zeros with the same shape and type (unless overridden) as `a`. + :rtype: TensorSSA + """ + return full_like(a, 0, dtype, loc=loc, ip=ip) + + +@dsl_user_op +def where( + cond: TensorSSA, + x: Union[TensorSSA, Numeric], + y: Union[TensorSSA, Numeric], + *, + loc=None, + ip=None, +) -> TensorSSA: + """ + Return elements chosen from x or y depending on condition; will auto broadcast x or y if needed. + + :param cond: Where True, yield x, where False, yield y. + :type cond: TensorSSA + :param x: Values from which to choose when condition is True. + :type x: Union[TensorSSA, Numeric] + :param y: Values from which to choose when condition is False. + :type y: Union[TensorSSA, Numeric] + :return: A tensor with elements from x where condition is True, and elements from y where condition is False. + :rtype: TensorSSA + """ + + # Helper function to promote scalars to tensors or broadcast tensors to target shape + def promote_and_broadcast(v, shape): + if isinstance(v, TensorSSA): + return v.broadcast_to(shape) + elif isinstance(v, (bool, int, float, ir.Value, Numeric)): + v = as_numeric(v) + return full(shape, v, v.dtype) + else: + raise ValueError(f"cannot promote {type(v)} to tensor") + + # Determine shapes for broadcasting - at least one input must be a tensor + x_is_tensor = isinstance(x, TensorSSA) + y_is_tensor = isinstance(y, TensorSSA) + if not (x_is_tensor or y_is_tensor): + raise ValueError( + f"at least one of x and y must be tensor, but got {type(x)} and {type(y)}" + ) + x_shape = x.shape if x_is_tensor else y.shape + y_shape = y.shape if y_is_tensor else x.shape + + # Promote both operands to tensors with broadcast shape + res_shape = _infer_broadcast_shape(cond.shape, x_shape, y_shape) + cond = promote_and_broadcast(cond, res_shape) + x = promote_and_broadcast(x, res_shape) + y = promote_and_broadcast(y, res_shape) + + if x.dtype != y.dtype: + raise ValueError( + f"x and y must have the same dtype, but got {x.dtype} and {y.dtype}" + ) + + if cond.dtype != Boolean: + raise ValueError(f"cond must be Boolean type, but got {cond.dtype}") + + cond_val = cond.ir_value(loc=loc, ip=ip) + res_val = arith.select(cond_val, x, y, loc=loc, ip=ip) + return TensorSSA(res_val, x.shape, x.dtype) + + +@dsl_user_op +def any_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: + """ + Test whether any tensor element evaluates to True. + + :param x: Input tensor. + :type x: TensorSSA + :return: Returns a TensorSSA scalar containing True if any element of x is True, False otherwise. + :rtype: TensorSSA + """ + is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) + return Boolean( + vector.reduction(T.bool(), vector.CombiningKind.OR, is_true, loc=loc, ip=ip) + ) + + +@dsl_user_op +def all_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: + """ + Test whether all tensor elements evaluate to True. + + :param x: Input tensor. + :type x: TensorSSA + :return: Returns a TensorSSA scalar containing True if all elements of x are True, False otherwise. + :rtype: TensorSSA + """ + is_true = x != full_like(x, 0, x.dtype, loc=loc, ip=ip) + return Boolean( + vector.reduction(T.bool(), vector.CombiningKind.AND, is_true, loc=loc, ip=ip) + ) diff --git a/python/CuTeDSL/cutlass/cute/testing.py b/python/CuTeDSL/cutlass/cute/testing.py index 88e0da04..ebb9824c 100644 --- a/python/CuTeDSL/cutlass/cute/testing.py +++ b/python/CuTeDSL/cutlass/cute/testing.py @@ -13,37 +13,37 @@ import functools import inspect import logging import os -from enum import Enum -from inspect import isclass from itertools import product from time import time -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Type, Union, Callable, Optional, Dict, List, Any import cuda.bindings.driver as cuda_driver import cuda.bindings.runtime as cuda_runtime -import numpy as np -import cutlass._mlir.ir as ir import cutlass.base_dsl.jit_executor +from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, dsl_user_op + +from .typing import Numeric, Int8, Boolean + import cutlass.cute as cute +from cutlass.cute import nvgpu + from cutlass._mlir.dialects import builtin, cf, nvvm, vector -from cutlass.cute import core, nvgpu -from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op @dsl_user_op def assert_(cond, msg=None, *, loc=None, ip=None): - cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip) + cf.assert_(Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip) -def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout): +def _maybe_recast_tensor_from_f4(src: cute.Tensor, tv_layout: cute.Layout): if src.element_type.width == 4: - tv_layout = core.recast_layout(8, 4, tv_layout) - src = core.recast_tensor(src, dtype=t.Int8) + tv_layout = cute.recast_layout(8, 4, tv_layout) + src = cute.recast_tensor(src, dtype=Int8) return src, tv_layout -def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]): +def _maybe_recast_to_f4(input: cute.TensorSSA, dtype: Type[Numeric]): """Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit. :param input: The input tensor to recast. @@ -51,22 +51,22 @@ def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]): :raises TypeError: If dtype is not a subclass of Numeric. :return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged. """ - if not isclass(dtype) or not issubclass(dtype, core.Numeric): + if not inspect.isclass(dtype) or not issubclass(dtype, Numeric): raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}") if dtype.width == 4: - recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape + recast_shape = cute.recast_layout(4, 8, cute.make_layout(input.shape)).shape i4_vec = vector.bitcast( T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast() ) res_vect = builtin.unrealized_conversion_cast( [T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec] ) - return core.TensorSSA(res_vect, recast_shape, dtype) + return cute.TensorSSA(res_vect, recast_shape, dtype) return input -def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]): +def _maybe_recast_from_f4(input: cute.TensorSSA, src_dtype: Type[Numeric]): """Conditionally recasts the tensor from 4-bit type if the source type is 4-bit. :param input: The input tensor to recast. @@ -74,27 +74,27 @@ def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]): :raises TypeError: If src_dtype is not a subclass of Numeric. :return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged. """ - if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric): + if not inspect.isclass(src_dtype) or not issubclass(src_dtype, Numeric): raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}") if src_dtype.width == 4: - recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape + recast_shape = cute.recast_layout(8, 4, cute.make_layout(input.shape)).shape i4_vec = builtin.unrealized_conversion_cast( [T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()] ) res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec) - return core.TensorSSA(res_vect, recast_shape, core.Int8) + return cute.TensorSSA(res_vect, recast_shape, Int8) return input @CuTeDSL.kernel def _convert_kernel( - gSrc: core.Tensor, - gDst: core.Tensor, - cSrc: core.Tensor, - src_tv_layout: core.Layout, - dst_tv_layout: core.Layout, - src_shape: core.Shape, + gSrc: cute.Tensor, + gDst: cute.Tensor, + cSrc: cute.Tensor, + src_tv_layout: cute.Layout, + dst_tv_layout: cute.Layout, + src_shape: cute.Shape, src_ty, dst_ty, ): @@ -110,9 +110,9 @@ def _convert_kernel( # compose with CTA TV layout # tid, vid -> address - tidfrgSrc = core.composition(ctaSrc, src_tv_layout) # (T,V) - tidfrgDst = core.composition(ctaDst, dst_tv_layout) # (T,V) - tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout) # (T,V) + tidfrgSrc = cute.composition(ctaSrc, src_tv_layout) # (T,V) + tidfrgDst = cute.composition(ctaDst, dst_tv_layout) # (T,V) + tidfrgCSrc = cute.composition(ctaCSrc, src_tv_layout) # (T,V) # print(f"tidfrgSrc = {tidfrgSrc.type}") # slice for threads @@ -123,19 +123,19 @@ def _convert_kernel( # print(f"thrSrc = {thrSrc.type}") # predicate - if core.elem_less(thrCSrc[0], src_shape): + if cute.elem_less(thrCSrc[0], src_shape): # allocate fragments for gmem->rmem - frgSrc = core.make_fragment( - core.get(src_tv_layout, mode=[1]), gSrc.element_type + frgSrc = cute.make_rmem_tensor( + cute.get(src_tv_layout, mode=[1]), gSrc.element_type ) # (V) - frgDst = core.make_fragment( - core.get(dst_tv_layout, mode=[1]), gDst.element_type + frgDst = cute.make_rmem_tensor( + cute.get(dst_tv_layout, mode=[1]), gDst.element_type ) # (V) # print(f"frgSrc = {frgSrc.type}") # Move data to reg address space - copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type) - core.copy(copy_atom_load, thrSrc, frgSrc) + copy_atom_load = cute.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type) + cute.copy(copy_atom_load, thrSrc, frgSrc) vec_src = frgSrc.load() vec_src = _maybe_recast_to_f4(vec_src, src_ty) @@ -144,49 +144,48 @@ def _convert_kernel( frgDst.store(vec_dst) # Copy the results back to c - copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type) - core.copy(copy_atom_stg, frgDst, thrDst) + copy_atom_stg = cute.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type) + cute.copy(copy_atom_stg, frgDst, thrDst) @CuTeDSL.jit(preprocess=False) def _convert( - src: core.Tensor, - dst: core.Tensor, + src: cute.Tensor, + dst: cute.Tensor, leading_mode: Constexpr, elem_per_copy: Constexpr, ): - # Step 1. figure proper tv_layout src_ty = src.element_type dst_ty = dst.element_type - tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1)) + tv_layout = cute.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1)) # Step 2. maybe recast from f4 tensor src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout) dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout) src_shape = src.shape # predicate tensor - idA = core.make_identity_tensor(src.shape) + idA = cute.make_identity_tensor(src.shape) # Step 3. select a proper tiling pattern as (...,TileV, ...) src_cta_tiler = [ 1, - ] * core.rank(src.layout) - src_cta_tiler[leading_mode] = core.size(src_tv_layout) # (...,TileV,...) + ] * cute.rank(src.layout) + src_cta_tiler[leading_mode] = cute.size(src_tv_layout) # (...,TileV,...) dst_cta_tiler = [ 1, - ] * core.rank(dst.layout) - dst_cta_tiler[leading_mode] = core.size(dst_tv_layout) # (...,TileV,...) + ] * cute.rank(dst.layout) + dst_cta_tiler[leading_mode] = cute.size(dst_tv_layout) # (...,TileV,...) # Step 4. partition input and output tensor by cta tiler. - gS = core.zipped_divide( + gS = cute.zipped_divide( src, tuple(src_cta_tiler) ) # ((...,TileV,...),(...,RestV,...)) - cS = core.zipped_divide( + cS = cute.zipped_divide( idA, tuple(src_cta_tiler) ) # ((...,TileV,...),(...,RestV,...)) - gD = core.zipped_divide( + gD = cute.zipped_divide( dst, tuple(dst_cta_tiler) ) # ((...,TileV,...),(...,RestV,...)) # print(f"{gS.type=}") @@ -201,8 +200,8 @@ def _convert( src_ty, dst_ty, ).launch( - grid=[core.size(gS, mode=[1]), 1, 1], - block=[core.size(src_tv_layout, mode=[0]), 1, 1], + grid=[cute.size(gS, mode=[1]), 1, 1], + block=[cute.size(src_tv_layout, mode=[0]), 1, 1], ) @@ -210,10 +209,10 @@ def _convert( # And when src or dst dtype is narrow precision(Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN), the shape of # their leading dimension should be 4(fp8)/8(fp4) element align. (nvgpu.cvt_fptrunc/cvt_fpext # needs 32-bits aligned input/output) -def convert(src: core.Tensor, dst: core.Tensor): - assert len(src.shape) == len( - dst.shape - ), "Shape of src and dst tensors should be the same rank." +def convert(src: cute.Tensor, dst: cute.Tensor): + assert len(src.shape) == len(dst.shape), ( + "Shape of src and dst tensors should be the same rank." + ) # find leading mode leading_mode = [ idx @@ -329,9 +328,9 @@ def _does_kernel_use_stream( :rtype: bool """ - assert int(stream) != int( - cuda_driver.CUstream_flags.CU_STREAM_DEFAULT - ), "Stream must be a non-default stream" + assert int(stream) != int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT), ( + "Stream must be a non-default stream" + ) err = cuda_runtime.cudaStreamBeginCapture( stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal @@ -474,8 +473,13 @@ def benchmark( elapsed_time = float("nan") if use_cuda_graphs: - # Check if the callable is a JitExecutor - if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor): + # Check if the callable is a JitCompiledFunction or JitExecutor + # These are functions that can be called to launch kernels + compiled_types = ( + cutlass.base_dsl.jit_executor.JitCompiledFunction, + cutlass.base_dsl.jit_executor.JitExecutor, + ) + if not isinstance(callable, compiled_types): raise TypeError("Function must be precompiled to be used with CUDA Graphs") # Check if the stream is a non-default stream @@ -502,7 +506,7 @@ def benchmark( # Assertion is >= since we may launch multiple kernels in one host function if num_nodes < warmup_iterations: raise ValueError( - f"CUDA stream passed to benchmark does not match the stream the kernel was launched in" + "CUDA stream passed to benchmark does not match the stream the kernel was launched in" ) # Capture profiling graph @@ -549,7 +553,6 @@ def benchmark( _cuda_success(err, "Error on destroying graph") else: - if int(stream) != int( cuda_driver.CUstream_flags.CU_STREAM_DEFAULT ) and not _does_kernel_use_stream( @@ -599,12 +602,474 @@ def get_workspace_count( :rtype: int """ num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes() - return max( - 1, - min( - warmup_iterations + iterations, # Don't create more workspaces than needed - (num_l2_cache_bytes + one_workspace_bytes - 1) - // one_workspace_bytes, # Ceiling division - ), - ) + num_workspaces = (num_l2_cache_bytes * 3) // one_workspace_bytes + 1 + num_iters = warmup_iterations + iterations + return num_iters if num_iters < num_workspaces else num_workspaces + +######################################### +# Autotuning/Tuning utilities +######################################### + + +def _benchmark_for_autotune( + callable: Callable, + *args, + warmup_iterations: int, + iterations: int, + use_cold_l2: bool, + print_verbose: bool, + current_stream: Optional[cuda_driver.CUstream] = None, + **kwargs, +) -> float: + """Benchmarks a callable function with the specified parameters. + + This function differs from the benchmark function in that it is used for autotuning. In this case we + do not loop through workspaces to keep the L2 cache cold. Instead we rely on writing to an L2 cache sized address to keep the L2 cache cold. + + The primary reason for doing this is that we do not have information on how to generate the workspaces for the kernel when autotuning. + We also do not have information on how much memory the workspaces take up. + + This benchmarking is done as a close approximation of the actual runtime of the kernel in an E2E system, + where we may have clock throttling, a warm cache, or other factors that could affect the runtime of the kernel. + + :param callable: The function to benchmark + :type callable: Callable + :param args: Arguments to pass to the callable function + :param warmup_iterations: Number of warmup iterations, defaults to 10 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations, defaults to 100 + :type iterations: int, optional + :param use_cold_l2: Whether to clear L2 cache between runs, defaults to True + :type use_cold_l2: bool, optional + :param print_verbose: Whether to print verbose output, defaults to False + :type print_verbose: bool, optional + :param current_stream: Stream to benchmark in, defaults to CUDA stream default + :type current_stream: CUstream, None + :param kwargs: Additional keyword arguments to pass to the callable function + + :return: The benchmark time in microseconds + :rtype: float + """ + if current_stream is None: + current_stream = cuda_driver.CUstream( + cuda_driver.CUstream_flags.CU_STREAM_DEFAULT + ) + + if int(current_stream) != int( + cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) + ) and not _does_kernel_use_stream(callable, current_stream, *args, **kwargs): + raise ValueError(f"Incorrect stream passed to kernel: {current_stream}") + + if use_cold_l2: + from cutlass.utils import HardwareInfo + + # use memset to clear L2 cache + hardware_info = HardwareInfo() + num_l2_cache_bytes = hardware_info.get_l2_cache_size_in_bytes() + err, cache_ptr = cuda_driver.cuMemAlloc(int(num_l2_cache_bytes)) + _cuda_success(err, "Error on allocating memory") + + # Create CUDA events for timing + err, start_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + err, end_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + try: + # warmup + for _ in range(warmup_iterations): + callable(*args, **kwargs) + + time = 0 + execution_time_ms = [] + for _ in range(iterations): + if use_cold_l2: + # clear L2 cache by memset to zero for every run + err = cuda_driver.cuMemsetD32Async( + cache_ptr, 0, int(num_l2_cache_bytes // 4), current_stream + ) + _cuda_success(err, "Error on memset") + err = cuda_driver.cuEventRecord(start_event, current_stream) + _cuda_success(err, "Error on recording event") + callable(*args, **kwargs) + err = cuda_driver.cuEventRecord(end_event, current_stream) + _cuda_success(err, "Error on recording event") + err = cuda_driver.cuEventSynchronize(end_event) + _cuda_success(err, "Error on synchronizing event") + err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying event") + execution_time_ms.append(elapsed_time) + # unit: us + time_us = sum(execution_time_ms) / len(execution_time_ms) + except Exception as e: + print(f"This config execution error: {e}") + time_us = float("inf") + if print_verbose: + print(f"Execution time: {time_us:.4f} us") + + if use_cold_l2: + err = cuda_driver.cuMemFree(cache_ptr) + _cuda_success(err, "Error on freeing memory") + err = cuda_driver.cuEventDestroy(start_event) + _cuda_success(err, "Error on destroying event") + err = cuda_driver.cuEventDestroy(end_event) + _cuda_success(err, "Error on destroying event") + return time_us + + +class autotune_jit: + """Auto-tuning tool supporting both dictionary and parameterized decorator styles. + The autotune_jit class can be used as a decorator or a function. + When used as a decorator, it will automatically tune the function based on the parameters. + When used as a function, it will return a decorator that can be used to decorate a function. + For example: + .. code-block:: python + + @autotune_jit(params_dict={'param1': [1, 2, 3], 'param2': [4, 5, 6]}, update_on_change=['param3']) + @cute.jit + def user_function(param1=1, param2=2, param3=3): + # contents of the function + pass + + The function will be automatically tuned over all combinations of param1 and param2 whenever param3 changes . + For non-specified parameters, the default value in user_function will be used (e.g., `param3` in `user_function`). + .. code-block:: python + user_function(a, b, c) # Autotunes code + user_function(a, b, c) # This call pulls the best kernel from cache + + Known Limitations: + - Only supports functions that are decorated with cute.jit + - If the function which is decorated with cute.jit is call method of a class, and the class has internal state that + is used as constexpr arguments in the function, the autotuner will not be able to find the best configuration. + + Note: The autotuner has the same semantics as cute.compile. If the function is compiled, but global variables are changed, + the autotuner will not recompile the kernel. + """ + + logger = None + + @classmethod + def _initialize_logger(cls): + """Ensure the logger is initialized""" + if cls.logger is None: + cls.logger = logging.getLogger(__name__ + "_Autotune") + if not cls.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + cls.logger.addHandler(handler) + if ( + os.environ.get("CUTE_DSL_LOG_AUTOTUNE") is not None + and os.environ.get("CUTE_DSL_LOG_AUTOTUNE") != "0" + ): + cls.logger.setLevel(logging.INFO) + + @classmethod + def _create_tuning_wrapper( + cls, func, warmup_iterations, iterations, autotune_update_params + ): + """Create a wrapper function that performs auto-tuning + + Args: + func: Original function + + Returns: + Decorated wrapper function + """ + + # Initialize autotune parameters + if not hasattr(func, "_autotune_params"): + func._original_func = func + func._autotune_params = {} + func._autotune_update_params = autotune_update_params + func._best_kernel = dict() + func._best_config = dict() + + # Create wrapper function for auto-tuning + @functools.wraps(func) + def tuning_wrapper(*args, **kwargs): + parameters = inspect.signature(func._original_func).parameters.keys() + tuning_key = list() + for param_name in func._autotune_update_params: + if param_name in kwargs.keys(): + tuning_key.append(kwargs[param_name]) + else: + index = list(parameters).index(param_name) + if index < len(args): + tuning_key.append(args[index]) + tuning_key = tuple(tuning_key) + if tuning_key in func._best_kernel.keys(): + cls.logger.info( + f"Using cached best configuration: {func._best_config[tuning_key]}" + ) + return func._best_kernel[tuning_key](*args, **kwargs) + + # Get all parameter configurations + params_dict = func._autotune_params + keys = list(params_dict.keys()) + values = list(params_dict.values()) + + min_time = float("inf") + + best_kernel = None + # Record start time + start = time() + + # Iterate through all possible configuration combinations + for config_values in product(*values): + # Build current configuration + current_config = dict(zip(keys, config_values)) + cls.logger.info(f"Tuning configuration: {current_config}") + + try: + # Call the original function, using current configuration to replace default parameters + # For example, if current_config contains "cluster_shape_mn": (2, 1) + # It will override func's default parameter value + merged_kwargs = {**kwargs, **current_config} + compiled_func = cute.compile( + func._original_func, *args, **merged_kwargs + ) + + # Detect which constexpr arguments we need to remove from args and merged_kwargs + # This is done because after compiling our function signature will change, removing all constexpr arguments. + indexes_to_remove = list() + for arg in compiled_func.args_spec.get_constexpr_args(): + if arg["argument_name"] in merged_kwargs: + del merged_kwargs[arg["argument_name"]] + elif arg["argument_index"] is not None: + indexes_to_remove.append(arg["argument_index"]) + if arg["argument_name"] not in func._autotune_update_params: + # Handle the case where the programmer avoided autotuning over constexpr values, and + # recompile in that case + func._autotune_update_params.append( + arg["argument_name"] + ) + + # Remove constexpr arguments from args + args_no_constexpr = list(args) + for index in sorted(indexes_to_remove, reverse=True): + del args_no_constexpr[index] + + # Benchmark the compiled function + cur_time = _benchmark_for_autotune( + compiled_func, + *args_no_constexpr, + warmup_iterations=warmup_iterations, + iterations=iterations, + use_cold_l2=True, + print_verbose=False, + **merged_kwargs, + ) + + cls.logger.info(f" Execution time: {cur_time} us") + + # Update best results + if cur_time < min_time: + min_time = cur_time + best_kernel = compiled_func + best_config = current_config + + except NotImplementedError as e: + cls.logger.info( + f" Encountered unimplemented error, abort execution: {e}" + ) + raise e + except (ValueError, TypeError) as e: + cls.logger.info(f" Configuration parameter skipping: {e}") + raise e + continue + except Exception as e: + cls.logger.info(f" Execution error skipping: {e}") + raise e + continue + + end = time() + tuning_time = end - start + + if best_kernel is None: + raise ValueError("No best kernel found") + + cls.logger.info( + f"Best configuration: {best_config}, execution time: {min_time} us" + ) + cls.logger.info(f"Total tuning time: {tuning_time} s") + func._best_kernel[tuning_key] = best_kernel + func._best_config[tuning_key] = best_config + return best_kernel(*args, **kwargs) + + # Append autotune wrapper to not conflict with the jit kernel names + tuning_wrapper.__name__ = func.__name__ + "_autotune_wrapper" + tuning_wrapper.__qualname__ = func.__qualname__ + "_autotune_wrapper" + + return tuning_wrapper + + return func # If already has a wrapper, return the original function + + def __init__( + self, + params_dict: Dict[str, List[Any]] = None, + update_on_change: List[str] = None, + warmup_iterations=10, + iterations=100, + ): + """Initialize the autotune_jit decorator. + + :param params_dict: Dictionary containing parameter names and their possible values + :type params_dict: Dict[str, List[Any]], optional + :param update_on_change: Whether to retune when the parameters changes, defaults to None + :type update_on_change: bool, optional + :param warmup_iterations: Number of warmup iterations, defaults to 100 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations, defaults to 100 + :type iterations: int, optional + """ + # Initialize logger + self._initialize_logger() + + # Save parameter dictionary + self.params_dict = params_dict or {} + self.update_on_change = update_on_change or list() + + # Save iterations + self.warmup_iterations = warmup_iterations + self.iterations = iterations + + def __call__(self, func): + """Called when class instance is used as a decorator. + + :param func: Function to be decorated + :type func: Callable + :return: Decorated function + :rtype: Callable + """ + # Create wrapper function + decorated_func = self._create_tuning_wrapper( + func, self.warmup_iterations, self.iterations, self.update_on_change + ) + + # Use the wrapper if it exists, otherwise use the original function + result_func = ( + decorated_func if hasattr(decorated_func, "_autotune_params") else func + ) + + # Add parameters from the dictionary to the function's autotune parameters + for param_name, param_values in self.params_dict.items(): + result_func._autotune_params[param_name] = param_values + + return result_func + + +def tune( + func: Callable[[Any], Callable[[], Any]], + params_dict: Dict[str, List[Any]] = None, + kernel_arguments: JitArguments = JitArguments(), + warmup_iterations=10, + iterations=100, + stream: Optional[cuda_driver.CUstream] = None, +) -> Dict[str, Any]: + """Tuning tool to suport arbitrary functions. The user must provide a function that returns a callable, which + takes no arguments to be tuned over. + Best practice is to return a jit function that is compiled with cute.compile for optimal performance. + For example: + .. code-block:: python + + def user_function(param1=1, param2=2, param3=3) -> Callable[[], Any]: + # contents of the function + return lambda : compiled_func(param1, param2, param3) + + config = tune(user_function, params_dict={'param1': [1, 2, 3], 'param2': [4, 5, 6]}, update_on_change=['param3']) + + :param func: Function to be tuned, note that errors raised in the function will be ignored and the next configuration will be tried. + :type func: Callable[[Any], Callable[[], Any]] + :param params_dict: Dictionary containing parameter names and their possible values + :type params_dict: Dict[str, List[Any]], optional + :param kernel_arguments: Kernel arguments to launch callable with, defaults to JitArguments() + :type kernel_arguments: JitArguments, optional + :param warmup_iterations: Number of warmup iterations, defaults to 10 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations, defaults to 100 + :type iterations: int, optional + :param stream: Stream kernel is launched in, defaults to CUDA stream default + :type stream: CUstream, None + :return: Best configuration + :rtype: Dict[str, Any] + """ + logger = logging.getLogger(__name__ + "_Autotune") + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + if ( + os.environ.get("CUTE_DSL_LOG_AUTOTUNE") is not None + and os.environ.get("CUTE_DSL_LOG_AUTOTUNE") != "0" + ): + logger.setLevel(logging.INFO) + + if stream is None: + stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) + + # Get all parameter configurations + keys = list(params_dict.keys()) + values = list(params_dict.values()) + + min_time = float("inf") + + best_config = None + # Record start time + start = time() + + # Iterate through all possible configuration combinations + for config_values in product(*values): + # Build current configuration + current_config = dict(zip(keys, config_values)) + logger.info(f"Tuning configuration: {current_config}") + + try: + merged_kwargs = {**kernel_arguments.kwargs, **current_config} + + compiled_func = func(*kernel_arguments.args, **merged_kwargs) + # Benchmark the compiled function + cur_time = _benchmark_for_autotune( + compiled_func, + warmup_iterations=warmup_iterations, + iterations=iterations, + use_cold_l2=True, + print_verbose=False, + current_stream=stream, + ) + + logger.info(f" Execution time: {cur_time} us") + + # Update best results + if cur_time < min_time: + min_time = cur_time + best_config = current_config + + except NotImplementedError as e: + logger.info(f" Encountered unimplemented error, abort execution: {e}") + raise e + except (ValueError, TypeError) as e: + logger.info(f" Configuration parameter skipping: {e}") + continue + except Exception as e: + logger.info(f" Execution error skipping: {e}") + continue + + end = time() + tuning_time = end - start + + if best_config is None: + raise ValueError("No best kernel found") + + logger.info(f"Best configuration: {best_config}, execution time: {min_time} us") + logger.info(f"Total tuning time: {tuning_time} s") + return best_config diff --git a/python/CuTeDSL/cutlass/cute/tuple.py b/python/CuTeDSL/cutlass/cute/tuple.py new file mode 100644 index 00000000..815ef08a --- /dev/null +++ b/python/CuTeDSL/cutlass/cute/tuple.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from inspect import signature +from itertools import chain +from typing import Any, Callable, Union, Tuple, List, Iterable + +from cutlass.cutlass_dsl import is_dynamic_expression, dsl_user_op +from cutlass._mlir import ir +import cutlass._mlir.dialects.cute as _cute_ir + +from .typing import XTuple, IntTuple, Shape, Coord, Boolean, is_integer + + +def wrap(x) -> Tuple[Any, ...]: + """ + Wraps the input into a tuple if not a tuple. + """ + if isinstance(x, tuple): + return x + return (x,) + + +def flatten_to_tuple(a: XTuple) -> Tuple[Any, ...]: + """Flattens a potentially nested tuple structure into a flat tuple. + + This function recursively traverses the input structure and flattens it into + a single-level tuple, preserving the order of elements. + + :param a: The structure to flatten + :type a: Union[IntTuple, Coord, Shape, Stride] + :return: A flattened tuple containing all elements from the input + :rtype: tuple + + **Examples:** + + .. code-block:: python + + flatten_to_tuple((1, 2, 3)) # Returns (1, 2, 3) + flatten_to_tuple(((1, 2), 3)) # Returns (1, 2, 3) + flatten_to_tuple((1, (2, (3,)))) # Returns (1, 2, 3) + """ + if not isinstance(a, tuple): + return wrap(a) + else: + return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a))) + + +def unflatten( + sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]], profile: XTuple +) -> XTuple: + """Unflatten a flat tuple into a nested tuple structure according to a profile. + + This function transforms a flat sequence of elements into a nested tuple structure + that matches the structure defined by the profile parameter. It traverses the profile + structure and populates it with elements from the sequence. + + sequence must be long enough to fill the profile. Raises RuntimeError if it is not. + + :param sequence: A flat sequence of elements to be restructured + :type sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]] + :param profile: A nested tuple structure that defines the shape of the output + :type profile: XTuple + :return: A nested tuple with the same structure as profile but containing elements from sequence + :rtype: XTuple + + **Examples:** + + .. code-block:: python + + unflatten([1, 2, 3, 4], ((0, 0), (0, 0))) # Returns ((1, 2), (3, 4)) + """ + + def _make_generator(): + for element in sequence: + yield element + + xs = _make_generator() + return transform_leaf(lambda _: next(xs), profile) + + +@dsl_user_op +def product(a: Union[IntTuple, Shape], *, loc=None, ip=None): + # Local import to avoid circular dependency + from .core import _pack_int_tuple, _unpack_x_tuple + + """Return product of the given IntTuple or Shape. + + Computes the product of all elements in the input tuple or shape. + + Returns static value if type is static otherwise dynamic value. + + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: Static product of IntTuple or Shape if static, otherwise a Value + :rtype: int or Value + :raises TypeError: If input is not an IntTuple or Shape + """ + if is_integer(a): + return a + if isinstance(a, tuple): + a_val = _pack_int_tuple(a, loc=loc, ip=ip) + res = _cute_ir.tuple_product(a_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + else: + raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") + + +@dsl_user_op +def product_like(a: IntTuple, target_profile: XTuple, *, loc=None, ip=None) -> IntTuple: + """Return product of the given IntTuple or Shape at leaves of `target_profile`. + + This function computes products according to the structure defined by target_profile. + + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param target_profile: The profile that guides how products are computed + :type target_profile: XTuple + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The resulting tuple with products computed according to target_profile + :rtype: IntTuple or Shape + :raises TypeError: If inputs have incompatible types + :raises ValueError: If inputs have incompatible shapes + """ + # Perform product at leaf of `target_profile` + if not isinstance(target_profile, tuple): + return product(a, loc=loc, ip=ip) + + if not isinstance(a, tuple): + raise TypeError(f"expects `a` tuple but got {a}") + + if len(a) != len(target_profile): + raise ValueError("expects `a` and `guide` have the same rank") + + return tuple(product_like(x, g, loc=loc, ip=ip) for x, g in zip(a, target_profile)) + + +@dsl_user_op +def product_each(a: IntTuple, *, loc=None, ip=None) -> IntTuple: + from .core import _pack_int_tuple, _unpack_x_tuple + + """Compute products for each component of the input. + + Returns a rank(a) tuple result such that ``get(result, mode=[i]) == product(get(a, mode=[i]))`` + + :param a: The input IntTuple or Shape + :type a: IntTuple or Shape + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: A tuple containing products for each component + :rtype: tuple + :raises TypeError: If input is not an IntTuple or Shape + """ + if is_integer(a): + return a + + if not isinstance(a, tuple): + raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") + + if a == (): + return 1 + + a_val = _pack_int_tuple(a, loc=loc, ip=ip) + res = _cute_ir.tuple_product_each(a_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + +def find_if( + t: Union[tuple, ir.Value, int], + pred_fn: Callable[[int, Tuple[int, ...]], bool], + *, + loc=None, + ip=None, +) -> Union[int, Tuple[int, ...], None]: + from .core import rank, get + + """Find the first position in t where pred_fn(val, pos) returns True. + + :param t: The search space + :type t: Union[tuple, ir.Value, int] + :param pred_fn: A callable object (lambda, function, etc.) that predicates the value and position in t. + It takes the current leaf value and position, returns True if the value or position is satisfied. + :type pred_fn: Callable[[int, Tuple[int, ...]], bool] + :return: Index if found at top level, tuple of indices showing nested position, or None if not found + :rtype: Union[int, Tuple[int, ...], None] + + **Examples:** + + .. code-block:: python + + # Find the first position of x in t + t = (3, 4) + find_if(t, pred_fn=lambda val, pos: val == x) + + .. code-block:: python + + # find the leading dimension + shape = (3, 4) + stride = (4, 1) + # Find value 1 in stride where the corresponding shape is not 1 + def pred_fn(val, pos): + mode = [pos] if isinstance(pos, int) else list(pos) + return val == 1 and get(shape, mode) != 1 + find_if(stride, pred_fn=pred_fn) + """ + + def _find_if_impl(curr, pos, *, loc=None, ip=None): + if isinstance(curr, tuple): + # Recursively search nested tuple + for i in range(rank(curr)): + sub_curr = get(curr, mode=[i], loc=loc, ip=ip) + sub_pos = (pos, i) if isinstance(pos, int) else pos + (i,) + res_pos = _find_if_impl(sub_curr, sub_pos, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + else: + # For leaf values, check if it matches x + if pred_fn(curr, pos): + return pos + return None + + def _check_pred_fn(): + if not callable(pred_fn): + raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") + + sig = signature(pred_fn) + if len(sig.parameters) != 2: + raise ValueError( + f"pred_fn must have two parameters (value, pos), but got {len(sig.parameters)}" + ) + + _check_pred_fn() + + for i in range(rank(t)): + curr = get(t, mode=[i], loc=loc, ip=ip) + res_pos = _find_if_impl(curr, i, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + return None + + +@dsl_user_op +def find( + t: Union[tuple, ir.Value, int], x: int, *, loc=None, ip=None +) -> Union[int, Tuple[int, ...], None]: + """Find the first position of a value ``x`` in a hierarchical structure ``t``. + + Searches for the first occurrence of x in t, optionally excluding positions + where a comparison value matches. The search can traverse nested structures + and returns either a single index or a tuple of indices for nested positions. + + :param t: The search space + :type t: Union[tuple, ir.Value, int] + :param x: The static integer x to search for + :type x: int + :return: Index if found at top level, tuple of indices showing nested position, or None if not found + :rtype: Union[int, Tuple[int, ...], None] + """ + if not isinstance(x, int): + raise TypeError(f"find() requires a static x to search for, but got {x}") + + def pred_fn(val, pos): + # Skip dynamic values which can't be compared + return not is_dynamic_expression(val) and val == x + + return find_if(t, pred_fn=pred_fn, loc=loc, ip=ip) + + +def transform_leaf(f, *args): + """ + Apply a function to the leaf nodes of nested tuple structures. + + This function traverses nested tuple structures in parallel and applies the function f + to corresponding leaf nodes. All input tuples must have the same nested structure. + + :param f: Function to apply to leaf nodes + :type f: Callable + :param args: One or more nested tuple structures with matching profiles + :return: A new nested tuple with the same structure as the inputs, but with leaf values transformed by f + :raises TypeError: If the input tuples have different nested structures + + **Example:** + + .. code-block:: python + + >>> transform_leaf(lambda x: x + 1, (1, 2)) + (2, 3) + >>> transform_leaf(lambda x, y: x + y, (1, 2), (3, 4)) + (4, 6) + >>> transform_leaf(lambda x: x * 2, ((1, 2), (3, 4))) + ((2, 4), (6, 8)) + """ + if all(isinstance(t, tuple) for t in args): + return tuple(transform_leaf(f, *_args) for _args in zip(*args)) + elif all(not isinstance(t, tuple) for t in args): + return f(*args) + else: + raise TypeError(f"profile of input tuples doesn't match: {args}") + + +@dsl_user_op +def elem_less( + lhs: Union[Shape, IntTuple, Coord], + rhs: Union[Shape, IntTuple, Coord], + *, + loc=None, + ip=None, +) -> Boolean: + from .core import _pack_coord + + # Coord is super set of IntTuple and Shape + lhs_val = _pack_coord(lhs, loc=loc, ip=ip) + rhs_val = _pack_coord(rhs, loc=loc, ip=ip) + return Boolean(_cute_ir.elem_less(lhs_val, rhs_val, loc=loc, ip=ip)) diff --git a/python/CuTeDSL/cutlass/cute/typing.py b/python/CuTeDSL/cutlass/cute/typing.py index 215e71d9..3048f550 100644 --- a/python/CuTeDSL/cutlass/cute/typing.py +++ b/python/CuTeDSL/cutlass/cute/typing.py @@ -10,13 +10,12 @@ # is strictly prohibited. from abc import ABC, abstractmethod -from typing import ForwardRef, Tuple, Union, Any, Type, List +from typing import ForwardRef, Tuple, Union, Any, Type, List, Optional from cutlass.base_dsl.typing import * from cutlass._mlir import ir -import cutlass._mlir.extras.types as T -from cutlass._mlir.dialects.cute import AddressSpace +from cutlass._mlir.dialects.cute import AddressSpace, ConstrainedIntType Int = Union[int, Integer] @@ -24,7 +23,6 @@ Int = Union[int, Integer] ScaledBasis = ForwardRef("ScaledBasis") - IntTuple = Union[Int, Tuple["IntTuple", ...]] Shape = Union[Int, Tuple["Shape", ...]] Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]] @@ -35,7 +33,7 @@ class Layout(ir.Value): def __init__(self, op_result): super().__init__(op_result) - def __str__(self): ... + def __str__(self) -> str: ... def get_hier_coord(self, idx) -> Coord: """Return the (hierarchical) ND logical coordinate corresponding to the linear index""" @@ -48,12 +46,95 @@ class Layout(ir.Value): def stride(self, *, loc=None, ip=None) -> Stride: ... +class ComposedLayout(ABC): + r"""ComposedLayout represents the functional composition of layouts in CuTe. + + **Formally:** + + .. math:: + + R(c) := (inner \circ offset \circ outer)(c) := inner(offset + outer(c)) + + where: + + - inner: The inner layout or swizzle that is applied last + - offset: An integer tuple representing a coordinate offset + - outer: The outer layout that is applied first + + This composition allows for complex transformations of coordinates and indices, + enabling operations like tiling, partitioning, and reshaping of data. + + :ivar inner: The inner layout or swizzle component + :ivar offset: The coordinate offset applied between inner and outer layouts + :ivar outer: The outer layout component + :ivar max_alignment: The maximum alignment of the composed layout + + **Examples:** + + .. code-block:: python + + # Create a composed layout with inner layout, offset, and outer layout + + # inner layout: (4, 8):(1, 4) + inner_layout = make_layout((4, 8)) + + offset = (0, 0) + + # outer layout: (2, 2):(1@0, 1@1) + outer_layout = make_layout((2, 2), stride=(1 * E(0), 1 * E(1))) + + # composed layout: (inner o offset o outer) + composed = make_composed_layout(inner_layout, offset, outer_layout) + + # Accessing components of the composed layout + inner = composed.inner + offset = composed.offset + outer = composed.outer + + # map coordinate (0, 1) to linear index + # - outer(0, 1) = (0, 1) + # - offset + outer(0, 1) = (0, 1) + # - inner(0, 1) = 0 * 1 + 1 * 4 = 4 + idx = crd2idx((0, 1), composed) + + # Composition is used in many tiling operations + # For example, in logical_product, raked_product, and blocked_product + """ + + @property + @abstractmethod + def type(self) -> ir.Type: ... + + @property + @abstractmethod + def is_normal(self) -> bool: ... + + @property + @abstractmethod + def inner(self, *, loc=None, ip=None): ... + + @property + @abstractmethod + def offset(self, *, loc=None, ip=None) -> IntTuple: ... + + @property + @abstractmethod + def outer(self, *, loc=None, ip=None) -> Layout: ... + + @property + @abstractmethod + def shape(self, *, loc=None, ip=None): ... + + @abstractmethod + def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: ... + + Tile = Union[Int, None, Layout, Tuple["Tile", ...]] +Tiler = Union[Shape, Layout, Tile] + # XTuple is super set of above types -XTuple = Union[IntTuple, Shape, Stride, Coord, Tile] - -Tiler = Union[Shape, Layout, Tile] +XTuple = Union[Any, Tuple["XTuple", ...]] class Pointer(ABC): @@ -70,6 +151,8 @@ class Pointer(ABC): def align(self, min_align: int) -> "Pointer": ... + def __add__(self, other: int, *, loc=None, ip=None) -> "Pointer": ... + def __get_mlir_types__(self) -> List[ir.Type]: ... def __extract_mlir_values__(self) -> List[ir.Value]: ... @@ -78,51 +161,69 @@ class Pointer(ABC): class Tensor(ABC): - """ - Abstract base class for CuTe jit function and runtime _Tensor + r"""Abstract base class for Tensor representations in CuTe DSL. - A CuTe Tensor is iterator with layout + A CuTe Tensor is iterator with layout. A tensor evaluates the layout by mapping a + coordinate to the codomain, offsets the iterator accordingly, and dereferences + the result to obtain the tensor's value. - :Examples: + **Formally:** + + .. math:: + + T(c) = (E \circ L)(c) = *(E + L(c)) + + where + + - :math:`E` is the iterator/engine + - :math:`L` is the layout + + **Notes:** + + - The tensor supports both direct element access via coordinates and slicing operations + - Load/store operations are only supported for specific memory spaces (rmem, smem, gmem, generic) + - For composed layouts, stride information is not directly accessible + - Dynamic layouts do not support vector load/store operations + + **Examples:** Create tensor from torch.tensor with Host Runtime: .. code-block:: python - >>> import torch - >>> from cutlass.cute.runtime import from_dlpack - >>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32)) - >>> mA.shape - (3,) - >>> mA.stride - (1,) - >>> mA.layout - (3,):(1,) + import torch + from cutlass.cute.runtime import from_dlpack + + mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32)) + print(mA.shape) # (3,) + print(mA.stride) # (1,) + print(mA.layout) # (3,):(1,) Define JIT function: .. code-block:: python @cute.jit - def add(a: Tensor, b: Tensor, res: Tensor): ... + def add(a: Tensor, b: Tensor, res: Tensor): + res.store(a.load() + b.load()) Call JIT function from python: .. code-block:: python - >>> import torch - >>> a = torch.tensor([1, 3, 5], dtype=torch.int32) - >>> b = torch.tensor([2, 4, 6], dtype=torch.int32) - >>> c = torch.zeros([3], dtype=torch.int32) - >>> mA = from_dlpack(a) - >>> mB = from_dlpack(b) - >>> mC = from_dlpack(c) - >>> add(mA, mB, mC) - >>> c - tensor([3, 7, 11], dtype=torch.int32) + import torch + a = torch.tensor([1, 3, 5], dtype=torch.int32) + b = torch.tensor([2, 4, 6], dtype=torch.int32) + c = torch.zeros([3], dtype=torch.int32) + mA = from_dlpack(a) + mB = from_dlpack(b) + mC = from_dlpack(c) + add(mA, mB, mC) + print(c) # tensor([3, 7, 11], dtype=torch.int32) """ - def __str__(self): ... + @abstractmethod + def __str__(self) -> str: ... @abstractmethod def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ... @@ -143,7 +244,7 @@ class Tensor(ABC): @property @abstractmethod - def iterator(self): ... + def iterator(self) -> Union[Pointer, IntTuple]: ... @property def layout(self) -> Union[Layout, "ComposedLayout"]: ... @@ -151,16 +252,19 @@ class Tensor(ABC): @property def shape(self) -> Shape: ... + @property + def stride(self) -> Stride: ... + def load(self, *, loc=None, ip=None) -> "TensorSSA": ... def store(self, data: "TensorSSA", *, loc=None, ip=None): ... - def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ... + def mark_layout_dynamic(self, leading_dim: Optional[int] = None) -> "Tensor": ... def mark_compact_shape_dynamic( self, mode: int, - stride_order: tuple[int, ...] | None = None, + stride_order: Optional[tuple[int, ...]] = None, divisibility: int = 1, ) -> "Tensor": ... @@ -168,11 +272,27 @@ class Tensor(ABC): def fill(self, value: Numeric) -> None: ... +def is_integer(a) -> bool: + """Check if an object is static integer or dynamic integer""" + return isinstance(a, (int, Integer)) or ( + isinstance(a, ir.Value) + and isinstance(a.type, (ir.IntegerType, ConstrainedIntType)) + ) + + +def is_int_tuple(a) -> bool: + if isinstance(a, tuple): + return all([is_int_tuple(x) for x in a]) + else: + return is_integer(a) + + __all__ = [ "Coord", "Numeric", "Integer", "Boolean", + "Int4", "Int8", "Int16", "Int32", @@ -204,4 +324,6 @@ __all__ = [ "Tile", "Tiler", "XTuple", + "is_integer", + "is_int_tuple", ] diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py b/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py index 06ea3f6f..10dc52d0 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py @@ -28,11 +28,11 @@ from ..base_dsl.ast_helpers import ( any_executor, all_executor, range_value_check, - range_perf_warning, cf_symbol_check, redirect_builtin_function, copy_members, get_locals_or_none, + closure_check, ) from ..base_dsl import * @@ -42,5 +42,21 @@ from ..base_dsl._mlir_helpers.gpu import * from ..base_dsl._mlir_helpers.op import dsl_user_op from ..base_dsl.runtime import * from ..base_dsl.runtime import cuda as cuda_helpers -from ..base_dsl.compiler import compile +from ..base_dsl.compiler import ( + CompileCallable, + OptLevel, + PtxasOptions, + EnableAssertions, + GenerateLineInfo, + KeepCUBIN, + KeepPTX, + GPUArch, + LinkLibraries, +) from ..base_dsl.runtime.jit_arg_adapters import * + + +from ..base_dsl.utils.logger import _init_logger_with_client_name + +# Initialize logger +_init_logger_with_client_name("CUTE_DSL") diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py index 1630c873..19684d9b 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py @@ -15,14 +15,26 @@ regarding to that dialect. """ # Local module imports -from itertools import chain from types import GenericAlias, SimpleNamespace, UnionType -from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any +from typing import ( + Callable, + Union, + List, + Tuple, + Sequence, + ForwardRef, + Any, + get_origin, + get_args, +) import functools import pkgutil from dataclasses import is_dataclass, fields +from math import ceil +from itertools import chain from collections.abc import Sequence import builtins +import ctypes from ..base_dsl import * from ..base_dsl import compiler @@ -30,8 +42,8 @@ from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values from ..base_dsl.typing import * from ..base_dsl.typing import DynamicExpression, get_mlir_types from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr +from ..base_dsl.runtime import cuda as cuda_helpers -from ..base_dsl.ast_helpers import const_expr # MLIR Imports from cutlass._mlir import ir, execution_engine, passmanager @@ -43,7 +55,8 @@ from cutlass._mlir.extras import types as T # Helpers from ..base_dsl._mlir_helpers import arith as cutlass_arith -from ..base_dsl._mlir_helpers import lru_cache_ir +from ..base_dsl._mlir_helpers.op import dsl_user_op +from ..base_dsl._mlir_helpers.arith import const from ..base_dsl.ast_helpers import ( loop_selector, @@ -60,7 +73,6 @@ from ..base_dsl.ast_helpers import ( any_executor, all_executor, range_value_check, - range_perf_warning, cf_symbol_check, ) @@ -132,6 +144,22 @@ def is_cute_algebra_type(arg_spec): return False +def _build_kernel_attrs(config) -> dict: + kernel_attrs = {} + if config.min_blocks_per_mp > 1: + kernel_attrs = { + cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: ceil( + config.min_blocks_per_mp + * config.smem + / cuda_helpers.get_device_attribute( + cuda_helpers.cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR + ) + * 100 + ) + } + return kernel_attrs + + def _get_c_pointers_cutlass(obj): """ This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict. @@ -203,8 +231,8 @@ class CutlassBaseDSL(BaseDSL): ) -> Any: return False - def _build_gpu_module(self, attrs): - self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels")) + def _build_gpu_module(self, attrs, loc=None): + self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"), loc=loc) with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])): pass @@ -232,9 +260,9 @@ class CutlassBaseDSL(BaseDSL): return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0]) def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict: - assert isinstance( - config, BaseDSL.LaunchConfig - ), f"Expect LaunchConfig for @kernel, but got {type(config)}" + assert isinstance(config, BaseDSL.LaunchConfig), ( + f"Expect LaunchConfig for @kernel, but got {type(config)}" + ) ret = {} # generate launch bound attr from LaunchConfig @@ -278,7 +306,7 @@ class CutlassBaseDSL(BaseDSL): version_hash.update(chunk) except Exception: raise DSLRuntimeError( - f"Failed to read the shared library file libCutlassIRPythonCAPI.so." + "Failed to read the shared library file libCutlassIRPythonCAPI.so." "The file may not exist or may not be readable." "Please re-install the package." ) @@ -315,6 +343,50 @@ class CutlassBaseDSL(BaseDSL): allocator, callback = self._smem_usage_tracker return callback(allocator) + @staticmethod + def gpu_launch_func( + async_token, + async_dependencies, + kernel, + grid_size_x, + grid_size_y, + grid_size_z, + block_size_x, + block_size_y, + block_size_z, + kernel_operands, + *, + cluster_size_x=None, + cluster_size_y=None, + cluster_size_z=None, + dynamic_shared_memory_size=None, + async_object=None, + use_pdl=False, + loc=None, + ip=None, + ) -> ir.Value: + op = gpu.LaunchFuncOp( + asyncToken=async_token, + asyncDependencies=async_dependencies, + kernel=kernel, + gridSizeX=grid_size_x, + gridSizeY=grid_size_y, + gridSizeZ=grid_size_z, + blockSizeX=block_size_x, + blockSizeY=block_size_y, + blockSizeZ=block_size_z, + kernelOperands=kernel_operands, + clusterSizeX=cluster_size_x, + clusterSizeY=cluster_size_y, + clusterSizeZ=cluster_size_z, + dynamicSharedMemorySize=dynamic_shared_memory_size, + asyncObject=async_object, + loc=loc, + ip=ip, + ) + op.attributes["use_pdl"] = ir.BoolAttr.get(use_pdl) + return _get_op_result_or_op_results(op) + def _kernel_helper(self, funcBody, *args, **kwargs): class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper): def __init__(self, dsl: CutlassBaseDSL): @@ -344,16 +416,17 @@ class CutlassBaseDSL(BaseDSL): kernelSym = kwargs.get("kernelSym", None) kernelOperands = kwargs.get("kernelOperands", None) requiredArgs = kwargs.get("requiredArgs", None) + loc = kwargs.get("loc", None) assert kernelSym is not None, "kernelSym being None is not expected!" - assert ( - requiredArgs is not None - ), "requiredArgs being None is not expected!" - assert ( - kernelOperands is not None - ), "kernelOperands being None is not expected!" - assert isinstance( - requiredArgs.config, BaseDSL.LaunchConfig - ), f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}" + assert requiredArgs is not None, ( + "requiredArgs being None is not expected!" + ) + assert kernelOperands is not None, ( + "kernelOperands being None is not expected!" + ) + assert isinstance(requiredArgs.config, BaseDSL.LaunchConfig), ( + f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}" + ) cfg = requiredArgs.config @@ -379,7 +452,7 @@ class CutlassBaseDSL(BaseDSL): if not isinstance(cfg.async_deps, (list, tuple)): cfg.async_deps = [cfg.async_deps] is_async = len(cfg.async_deps) > 0 - token = gpu.launch_func( + token = CutlassBaseDSL.gpu_launch_func( gpu.AsyncTokenType.get() if is_async else None, cfg.async_deps, kernelSym, @@ -393,6 +466,8 @@ class CutlassBaseDSL(BaseDSL): ) ), dynamic_shared_memory_size=cfg.smem, + use_pdl=cfg.use_pdl, + loc=loc, ) return token if is_async else None @@ -432,7 +507,7 @@ class CutlassBaseDSL(BaseDSL): expected_base = get_args(arg_annotation)[0] if not issubclass(arg, expected_base): return DSLRuntimeError( - f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}" + f"expects argument #{arg_index + 1} ({arg_name}) to be Type[{expected_base}], but got {arg}" ) # Handle Union types and generic types elif origin is Union or isinstance(arg_annotation, UnionType): @@ -445,13 +520,17 @@ class CutlassBaseDSL(BaseDSL): for ty in allowed_types ): return DSLRuntimeError( - f"expects argument #{arg_index+1} ({arg_name}) to be one of {allowed_types}, but got {type(arg)}" + f"expects argument #{arg_index + 1} ({arg_name}) to be one of {allowed_types}, but got {type(arg)}" ) + elif isinstance(arg_annotation, GenericAlias): + # skip generic types such as List[int], Tuple[int, int], etc. for performance consideration? + pass + elif isinstance(arg_annotation, type): # Handle simple type annotations if not isinstance(arg, arg_annotation) and arg is not None: return DSLRuntimeError( - f"expects argument #{arg_index+1} ({arg_name}) to be {arg_annotation}, but got {type(arg)}" + f"expects argument #{arg_index + 1} ({arg_name}) to be {arg_annotation}, but got {type(arg)}" ) # Everything looks good if we are here return None @@ -618,6 +697,7 @@ class KernelLauncher: self.dsl.frame = inspect.currentframe().f_back self.dsl._preprocess_launch_config_args(args, kwargs) config = self.dsl.LaunchConfig(*args, **kwargs) + kernel_attrs = _build_kernel_attrs(config) kernel_generator = self.dsl.kernel_launcher( requiredArgs=["config"], @@ -627,7 +707,7 @@ class KernelLauncher: )(self.funcBody) ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config) - self.dsl.kernel_symbols.append(name) + self.dsl.kernel_info[name] = kernel_attrs self.dsl.frame = None return ret.launch_op_ret @@ -790,9 +870,9 @@ def to_index(value): if is_dynamic_expression(value): if isinstance(value, Numeric): value = value.ir_value() - assert ir.IntegerType.isinstance( - value.type - ), f"expects integer type, but got {value.type}" + assert ir.IntegerType.isinstance(value.type), ( + f"expects integer type, but got {value.type}" + ) res = arith.index_cast(T.index(), value) else: res = const(int(value), ty=T.index()) @@ -831,7 +911,6 @@ def _validate_iter_args_structure(iter_args, ir_values): return count_values(iter_args) == len(ir_values) - # ============================================================================= # DSL implementation of Python Build-in Operators # ============================================================================= @@ -839,7 +918,7 @@ def _validate_iter_args_structure(iter_args, ir_values): def _minmax(op, *args, loc=None, ip=None): """Computes the minimum or maximum value from the provided arguments.""" - from ..base_dsl.typing import _binary_op, _binary_op_type_promote + from ..base_dsl.typing import _binary_op_type_promote # AST Traversal doesn't support early exit in if executor x = None @@ -854,34 +933,39 @@ def _minmax(op, *args, loc=None, ip=None): # Handle case for min(a, b, c, ...) and min([x, y], [b]) and min(a, (x, y, z)) elif len(args) > 1: res, *xs = tuple(args) + for x in xs: - lhs = as_numeric(op(res, loc=loc, ip=ip)) - rhs = as_numeric(op(x, loc=loc, ip=ip)) emitter = getattr(cutlass_arith, f"_{op.__name__}") - - lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool=True) - - if isinstance(lhs.value, cutlass_arith.ArithValue) and isinstance( - lhs, Integer - ): - lhs_val = lhs.value.with_signedness(lhs.signed) + if not (is_dynamic_expression(res) or is_dynamic_expression(x)): + res = emitter(op(res), op(x)) else: - lhs_val = lhs.value + lhs = as_numeric(op(res, loc=loc, ip=ip)) + rhs = as_numeric(op(x, loc=loc, ip=ip)) + lhs, rhs, res_type = _binary_op_type_promote( + lhs, rhs, promote_bool=True + ) - if isinstance(rhs.value, cutlass_arith.ArithValue) and isinstance( - rhs, Integer - ): - rhs_val = rhs.value.with_signedness(rhs.signed) - else: - rhs_val = rhs.value + if isinstance(lhs.value, cutlass_arith.ArithValue) and isinstance( + lhs, Integer + ): + lhs_val = lhs.value.with_signedness(lhs.signed) + else: + lhs_val = lhs.value - res = res_type(emitter(lhs_val, rhs_val), loc=loc, ip=ip) + if isinstance(rhs.value, cutlass_arith.ArithValue) and isinstance( + rhs, Integer + ): + rhs_val = rhs.value.with_signedness(rhs.signed) + else: + rhs_val = rhs.value + res = res_type(emitter(lhs_val, rhs_val), loc=loc, ip=ip) x = res else: raise DSLNotImplemented(f"{type(args)} is not supported") return x +@dsl_user_op def min(*args, loc=None, ip=None): """Computes the minimum value from the provided arguments. @@ -940,6 +1024,7 @@ def min(*args, loc=None, ip=None): return _minmax(min, *args, loc=loc, ip=ip) +@dsl_user_op def max(*args, loc=None, ip=None): """Computes the maximum value from the provided arguments. @@ -1275,7 +1360,7 @@ def for_generate( def _createI32Attr(value): if not isinstance(value, int): - raise DSLRuntimeError(f"value must be int.") + raise DSLRuntimeError("value must be int.") return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value) ir_iter_args = extract_mlir_values(iter_args) if iter_args is not None else None @@ -1529,23 +1614,22 @@ def in_(lhs, rhs): def _lte_gte(lhs, rhs, op): def native_lte_gte(lhs, rhs, op): - match op: - case "<": - return lhs < rhs - case "<=": - if hasattr(lhs, "__le__"): - return lhs <= rhs - else: - return not_(lhs > rhs) - case ">": - return lhs > rhs - case ">=": - if hasattr(lhs, "__ge__"): - return lhs >= rhs - else: - return not_(lhs < rhs) - case _: - raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + if op == "<": + return lhs < rhs + elif op == "<=": + if hasattr(lhs, "__le__"): + return lhs <= rhs + else: + return not_(lhs > rhs) + elif op == ">": + return lhs > rhs + elif op == ">=": + if hasattr(lhs, "__ge__"): + return lhs >= rhs + else: + return not_(lhs < rhs) + else: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): return native_lte_gte(lhs, rhs, op) @@ -1571,15 +1655,14 @@ def _lte_gte(lhs, rhs, op): # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types # If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one has_valid_mask = any_(mask) - match op: - case "<": - length_result = len(lhs) < len(rhs) - case ">": - length_result = len(lhs) > len(rhs) - case "<=": - length_result = len(lhs) <= len(rhs) - case ">=": - length_result = len(lhs) >= len(rhs) + if op == "<": + length_result = len(lhs) < len(rhs) + elif op == ">": + length_result = len(lhs) > len(rhs) + elif op == "<=": + length_result = len(lhs) <= len(rhs) + elif op == ">=": + length_result = len(lhs) >= len(rhs) if type(has_valid_mask) == bool: return result if has_valid_mask else length_result else: @@ -1624,30 +1707,29 @@ def _compare_dispatch(lhs, rhs, op): :return: The result of the comparison, which may be a boolean or a DSL-specific type. :raises DSLRuntimeError: If the operator is not supported. """ - match op: - # 'is' and 'is not' are pure python operators - case "is": - return lhs is rhs - case "is not": - return lhs is not rhs - case "in": - return in_(lhs, rhs) - case "not in": - return not_(in_(lhs, rhs)) - case "==": - return equal(lhs, rhs) - case "!=": - return not_equal(lhs, rhs) - case "<": - return less_than(lhs, rhs) - case ">": - return greater_than(lhs, rhs) - case ">=": - return greater_equal(lhs, rhs) - case "<=": - return less_equal(lhs, rhs) - case _: - raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + # 'is' and 'is not' are pure python operators + if op == "is": + return lhs is rhs + elif op == "is not": + return lhs is not rhs + elif op == "in": + return in_(lhs, rhs) + elif op == "not in": + return not_(in_(lhs, rhs)) + elif op == "==": + return equal(lhs, rhs) + elif op == "!=": + return not_equal(lhs, rhs) + elif op == "<": + return less_than(lhs, rhs) + elif op == ">": + return greater_than(lhs, rhs) + elif op == ">=": + return greater_equal(lhs, rhs) + elif op == "<=": + return less_equal(lhs, rhs) + else: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") def _compare_executor(left, comparators, ops): diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py index b5b4d895..c0b55f1a 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py @@ -9,25 +9,17 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import List, Tuple -from types import NoneType +from typing import List + from cutlass._mlir import ir -from cutlass._mlir.dialects import scf, arith -from cutlass._mlir.extras import types as T +from cutlass._mlir.dialects import scf from collections.abc import Sequence from ..base_dsl.dsl import is_dynamic_expression from ..base_dsl.ast_helpers import * from ..base_dsl.utils.logger import log from ..base_dsl import typing as t -from ..base_dsl.typing import ( - Int32, - Float32, - Boolean, - Numeric, - get_mlir_types, - as_numeric, -) +from ..base_dsl.typing import Boolean, Numeric, as_numeric from . import cutlass as cutlass_dsl from .tree_utils import PyTreeDef, check_tree_equal @@ -35,6 +27,8 @@ from .tree_utils import PyTreeDef, check_tree_equal # AST Helpers # ============================================================================= +NoneType = type(None) + class LoopUnroll(ir.Attribute): def __init__(self, **kwargs): @@ -312,9 +306,12 @@ def _loop_execute_range_dynamic( start_ = t.as_numeric(start) stop_ = t.as_numeric(stop) step_ = t.as_numeric(step) - assert start_ is not t.Int32, "Start is required for scf.for" - assert stop_ is not t.Int32, "Stop is required for scf.for" - assert step_ is not t.Int32, "Step is required for scf.for" + if start_.dtype is not t.Int32: + raise DSLRuntimeError(f"expected Int32 for start, got {start_.dtype}") + if stop_.dtype is not t.Int32: + raise DSLRuntimeError(f"expected Int32 for stop, got {stop_.dtype}") + if step_.dtype is not t.Int32: + raise DSLRuntimeError(f"expected Int32 for step, got {step_.dtype}") start_ = start_.ir_value() stop_ = stop_.ir_value() step_ = step_.ir_value() diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/tree_utils.py b/python/CuTeDSL/cutlass/cutlass_dsl/tree_utils.py index 599b72ea..2bc94269 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/tree_utils.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/tree_utils.py @@ -19,6 +19,8 @@ from ..base_dsl._mlir_helpers.arith import ArithValue from ..base_dsl.common import DSLBaseError from .._mlir import ir +NoneType = type(None) + # ============================================================================= # Tree Utils # ============================================================================= @@ -124,6 +126,7 @@ def is_constexpr_field(field: dataclasses.Field) -> bool: # PyTreeDef # ============================================================================= + class NodeType(NamedTuple): """ Represents a node in a pytree structure. @@ -133,6 +136,7 @@ class NodeType(NamedTuple): to_iterable: Function to convert node to iterable form from_iterable: Function to reconstruct node from iterable form """ + name: str to_iterable: Callable from_iterable: Callable @@ -147,6 +151,7 @@ class PyTreeDef(NamedTuple): node_metadata: SimpleNamespace metadata associated with this node child_treedefs: Tuple of child tree definitions """ + node_type: NodeType node_metadata: SimpleNamespace child_treedefs: tuple["PyTreeDef", ...] @@ -163,6 +168,7 @@ class Leaf: node_metadata: SimpleNamespace metadata associated with this leaf ir_type_str: String representation of the IR type """ + is_numeric: bool = False is_none: bool = False node_metadata: SimpleNamespace = None @@ -267,6 +273,7 @@ def set_dataclass_attributes( kwargs[field] = getattr(instance, field) return dataclasses.replace(instance, **kwargs) + def default_dataclass_from_iterable( metadata: SimpleNamespace, children: Iterable[Any] ) -> Any: @@ -509,7 +516,7 @@ def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]: return list(children_iter), treedef -def get_registered_node_types_or_insert(x: Any) -> NodeType | None: +def get_registered_node_types_or_insert(x: Any) -> Union[NodeType, None]: """ Get the registered node type for an object, registering it if necessary. @@ -569,7 +576,7 @@ def create_leaf_for_value( ) -def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]: +def _tree_flatten(x: Any) -> tuple[Iterable[Any], Union[PyTreeDef, Leaf]]: """ Internal function to flatten a tree structure. @@ -587,50 +594,52 @@ def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]: Raises: DSLTreeFlattenError: If the object type is not supported """ - match x: - case None: - return [], create_leaf_for_value(x, is_none=True) + if x is None: + return [], create_leaf_for_value(x, is_none=True) - case ArithValue() if is_dynamic_expression(x): - v = x.__extract_mlir_values__() - return v, create_leaf_for_value( - x, - node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), - ir_type_str=str(v[0].type), - ) + elif isinstance(x, ArithValue) and is_dynamic_expression(x): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) - case ArithValue(): - return [x], create_leaf_for_value(x, is_numeric=True) + elif isinstance(x, ArithValue): + return [x], create_leaf_for_value(x, is_numeric=True) - case ir.Value(): - return [x], create_leaf_for_value(x) + elif isinstance(x, ir.Value): + return [x], create_leaf_for_value(x) - case Numeric(): - v = x.__extract_mlir_values__() - return v, create_leaf_for_value( - x, - node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), - ir_type_str=str(v[0].type), - ) + elif isinstance(x, Numeric): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) - case _: - node_type = get_registered_node_types_or_insert(x) - if node_type: - node_metadata, children = node_type.to_iterable(x) - children_flat, child_trees = unzip2(map(_tree_flatten, children)) - flattened = it.chain.from_iterable(children_flat) - return flattened, PyTreeDef( - node_type, node_metadata, tuple(child_trees) - ) - - # Try to convert to numeric - try: - nval = as_numeric(x).ir_value() - return [nval], create_leaf_for_value(nval, is_numeric=True) - except Exception: + else: + node_type = get_registered_node_types_or_insert(x) + if node_type: + node_metadata, children = node_type.to_iterable(x) + if children is None: + # Flatten should not return None, it should return an empty list for real empty cases raise DSLTreeFlattenError( - "Flatten Error", get_fully_qualified_class_name(x) + "Flatten Error: children is None", get_fully_qualified_class_name(x) ) + children_flat, child_trees = unzip2(map(_tree_flatten, children)) + flattened = it.chain.from_iterable(children_flat) + return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees)) + + # Try to convert to numeric + try: + nval = as_numeric(x).ir_value() + return [nval], create_leaf_for_value(nval, is_numeric=True) + except Exception: + raise DSLTreeFlattenError( + "Flatten Error", get_fully_qualified_class_name(x) + ) def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any: @@ -656,7 +665,7 @@ def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any: return _tree_unflatten(treedef, iter(xs)) -def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any: +def _tree_unflatten(treedef: Union[PyTreeDef, Leaf], xs: Iterator[Any]) -> Any: """ Internal function to reconstruct a tree structure. @@ -670,24 +679,18 @@ def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any: Returns: The reconstructed object """ - match treedef: - case Leaf(is_none=True): + if isinstance(treedef, Leaf): + if getattr(treedef, "is_none", False): return None - - case Leaf( - node_metadata=metadata - ) if metadata and metadata.is_dynamic_expression: + metadata = getattr(treedef, "node_metadata", None) + if metadata and getattr(metadata, "is_dynamic_expression", False): return metadata.original_obj.__new_from_mlir_values__([next(xs)]) - - case Leaf(is_numeric=True): + if getattr(treedef, "is_numeric", False): return as_numeric(next(xs)) - - case Leaf(): - return next(xs) - - case PyTreeDef(): - children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) - return treedef.node_type.from_iterable(treedef.node_metadata, children) + return next(xs) + elif isinstance(treedef, PyTreeDef): + children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) + return treedef.node_type.from_iterable(treedef.node_metadata, children) def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool: @@ -704,29 +707,26 @@ def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) Returns: bool: True if the trees are structurally equal, False otherwise """ - match (lhs, rhs): - case (Leaf(), Leaf()): - return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str + if isinstance(lhs, Leaf) and isinstance(rhs, Leaf): + return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str + elif isinstance(lhs, PyTreeDef) and isinstance(rhs, PyTreeDef): + lhs_metadata = lhs.node_metadata + rhs_metadata = rhs.node_metadata - case (PyTreeDef(), PyTreeDef()): - lhs_metadata = lhs.node_metadata - rhs_metadata = rhs.node_metadata + lhs_fields = getattr(lhs_metadata, "fields", []) + rhs_fields = getattr(rhs_metadata, "fields", []) + lhs_constexpr_fields = getattr(lhs_metadata, "constexpr_fields", []) + rhs_constexpr_fields = getattr(rhs_metadata, "constexpr_fields", []) - lhs_fields = getattr(lhs_metadata, "fields", []) - rhs_fields = getattr(rhs_metadata, "fields", []) - lhs_constexpr_fields = getattr(lhs_metadata, "constexpr_fields", []) - rhs_constexpr_fields = getattr(rhs_metadata, "constexpr_fields", []) - - return ( - lhs.node_type == rhs.node_type - and lhs_fields == rhs_fields - and lhs_constexpr_fields == rhs_constexpr_fields - and len(lhs.child_treedefs) == len(rhs.child_treedefs) - and all(map(_check_tree_equal, lhs.child_treedefs, rhs.child_treedefs)) - ) - - case _: - return False + return ( + lhs.node_type == rhs.node_type + and lhs_fields == rhs_fields + and lhs_constexpr_fields == rhs_constexpr_fields + and len(lhs.child_treedefs) == len(rhs.child_treedefs) + and all(map(_check_tree_equal, lhs.child_treedefs, rhs.child_treedefs)) + ) + else: + return False def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int: @@ -752,7 +752,7 @@ def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int: assert len(lhs.child_treedefs) == len(rhs.child_treedefs) def find_first_difference( - index_and_pair: tuple[int, tuple[PyTreeDef, PyTreeDef]] + index_and_pair: tuple[int, tuple[PyTreeDef, PyTreeDef]], ) -> int: index, (l, r) = index_and_pair return index if not _check_tree_equal(l, r) else -1 diff --git a/python/CuTeDSL/cutlass/pipeline/__init__.py b/python/CuTeDSL/cutlass/pipeline/__init__.py index 7df24dd6..2f88819c 100644 --- a/python/CuTeDSL/cutlass/pipeline/__init__.py +++ b/python/CuTeDSL/cutlass/pipeline/__init__.py @@ -35,6 +35,7 @@ from .sm90 import ( PipelineTmaAsync, PipelineTmaMultiConsumersAsync, PipelineTmaStore, + PipelineOrder, PipelineProducer, PipelineConsumer, ) @@ -52,6 +53,7 @@ __all__ = [ "SyncObject", "MbarrierArray", "NamedBarrier", + "PipelineOrder", "TmaStoreFence", "PipelineUserType", "PipelineState", @@ -65,4 +67,12 @@ __all__ = [ "PipelineTmaStore", "PipelineProducer", "PipelineConsumer", + "make_pipeline_state", + "pipeline_init_wait", + "arrive", + "arrive_unaligned", + "wait", + "wait_unaligned", + "arrive_and_wait", + "sync", ] diff --git a/python/CuTeDSL/cutlass/pipeline/helpers.py b/python/CuTeDSL/cutlass/pipeline/helpers.py index b5b94899..480a0527 100644 --- a/python/CuTeDSL/cutlass/pipeline/helpers.py +++ b/python/CuTeDSL/cutlass/pipeline/helpers.py @@ -44,17 +44,17 @@ class CooperativeGroup: CooperativeGroup contains size and alignment restrictions for an Agent. """ - def __init__(self, agent: Agent, size: int = 1, alignment: int = 1): + def __init__(self, agent: Agent, size: int = 1, alignment=None): + if alignment is not None: + warnings.warn( + "The 'alignment' parameter of CooperativeGroup's constructor is deprecated and " + "will be removed in a subsequent release, please remove it from your code.", + DeprecationWarning, + stacklevel=2, + ) + if agent is Agent.Thread: assert size > 0 - if size == 32: - assert ( - size == alignment - ), "Error: Alignment does not match number of threads in a warp." - elif size == 128: - assert ( - size == alignment - ), "Error: Alignment does not match number of threads in a warpgroup." elif agent is Agent.ThreadBlock: raise NotImplementedError("Error: Not yet supported.") elif agent is Agent.ThreadBlockCluster: @@ -222,18 +222,18 @@ class MbarrierArray(SyncObject): if self.op_type is PipelineOp.AsyncThread: self.arrive_mbarrier(index, dst) elif self.op_type is PipelineOp.TCGen05Mma: - assert ( - cta_group is not None - ), "Error: CTA group must be provided for TCGen05Mma." + assert cta_group is not None, ( + "Error: CTA group must be provided for TCGen05Mma." + ) self.arrive_tcgen05mma(index, dst, cta_group) elif self.op_type in [PipelineOp.TmaLoad]: self.arrive_and_expect_tx(index, self.tx_count) elif self.op_type is PipelineOp.AsyncLoad: self.arrive_cp_async_mbarrier(index) else: - assert ( - False - ), f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}." + assert False, ( + f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}." + ) def arrive_mbarrier(self, index: int, dst_rank: Optional[int] = None) -> None: if dst_rank is None: @@ -351,9 +351,6 @@ class NamedBarrier(SyncObject): self.arrive_and_wait() def wait_unaligned(self) -> None: - warnings.warn( - "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." - ) llvm.inline_asm( None, [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()], @@ -373,7 +370,7 @@ class NamedBarrier(SyncObject): raise NotImplementedError("Error: Not supported.") def sync(self) -> None: - cute.arch.barrier(barrier_id=self.barrier_id) + self.arrive_and_wait() def get_barrier(self) -> int: return self.barrier_id @@ -409,9 +406,9 @@ class TmaStoreFence(SyncObject): # TmaStoreFence doesn't have mbarriers def get_barrier(self) -> None: - assert ( - False - ), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier." + assert False, ( + "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier." + ) def max(self) -> None: raise NotImplementedError("Error: Not supported.") @@ -538,9 +535,9 @@ def make_pipeline_state(type: PipelineUserType, stages: int): Int32(0), ) else: - assert ( - False - ), "Error: invalid PipelineUserType specified for make_pipeline_state." + assert False, ( + "Error: invalid PipelineUserType specified for make_pipeline_state." + ) ############################################################################## @@ -574,9 +571,9 @@ def _sync(group: Agent): cute.arch.cluster_arrive() cute.arch.cluster_wait() else: - assert ( - False - ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + assert False, ( + "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + ) def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer: diff --git a/python/CuTeDSL/cutlass/pipeline/sm100.py b/python/CuTeDSL/cutlass/pipeline/sm100.py index 2feed8cc..60cb83e2 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm100.py +++ b/python/CuTeDSL/cutlass/pipeline/sm100.py @@ -9,17 +9,14 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -import enum -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Union -import warnings +from typing import Optional +import cutlass import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, if_generate from cutlass.pipeline import ( - Agent, CooperativeGroup, PipelineOp, PipelineState, @@ -42,7 +39,9 @@ class PipelineTmaUmma(PipelineAsync): cta_group: cute.nvgpu.tcgen05.CtaGroup @staticmethod - def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout): + def _compute_mcast_arrival_mask( + cta_layout_vmnk: cute.Layout, mcast_mode_mn: tuple[int, int] + ): """ Computes a mask for signaling arrivals to multicasting threadblocks. """ @@ -69,12 +68,18 @@ class PipelineTmaUmma(PipelineAsync): cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 ) - return ( - tma_mcast_mask_a - | tma_mcast_mask_b - | tma_mcast_mask_a_peer - | tma_mcast_mask_b_peer - ) + assert not (mcast_mode_mn[0] == 0 and mcast_mode_mn[1] == 0) + if mcast_mode_mn[0] == 1 and mcast_mode_mn[1] == 1: + return ( + tma_mcast_mask_a + | tma_mcast_mask_b + | tma_mcast_mask_a_peer + | tma_mcast_mask_b_peer + ) + elif mcast_mode_mn[1] == 1: + return tma_mcast_mask_b | tma_mcast_mask_b_peer + assert mcast_mode_mn[0] == 1 + return tma_mcast_mask_a | tma_mcast_mask_a_peer @staticmethod def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): @@ -100,6 +105,7 @@ class PipelineTmaUmma(PipelineAsync): tx_count: int, barrier_storage: cute.Pointer = None, cta_layout_vmnk: Optional[cute.Layout] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), ): """ This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. @@ -115,6 +121,8 @@ class PipelineTmaUmma(PipelineAsync): :type tx_count: int :param cta_layout_vmnk: Layout of the cluster shape :type cta_layout_vmnk: cute.Layout | None + :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. + :type mcast_mode_mn: tuple[int, int] """ if not isinstance(barrier_storage, cute.Pointer): raise ValueError( @@ -140,7 +148,9 @@ class PipelineTmaUmma(PipelineAsync): # All threadblocks are leaders if not using clusters is_leader_cta = True else: - producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask( + cta_layout_vmnk, mcast_mode_mn + ) is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) cta_group = ( @@ -429,6 +439,7 @@ class PipelineUmmaAsync(PipelineAsync): """ self.sync_object_full.arrive(state.index, self.producer_mask, self.cta_group) + @cute.jit def producer_tail(self, state: PipelineState): """ Make sure the last used buffer empty signal is visible to producer. @@ -443,11 +454,9 @@ class PipelineUmmaAsync(PipelineAsync): ) is_leader_cta = cta_rank_in_cluster % 2 == 0 - def then_body(): + if is_leader_cta: # Assume state contains that next useful buffer # So we only need to advance to num_stages - 1 times to last used buffer - for i in range(self.num_stages - 1): + for i in cutlass.range_constexpr(self.num_stages - 1): state.advance() self.producer_acquire(state) - - if_generate(is_leader_cta, then_body) diff --git a/python/CuTeDSL/cutlass/pipeline/sm90.py b/python/CuTeDSL/cutlass/pipeline/sm90.py index 5fc19960..8e926354 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm90.py +++ b/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -9,26 +9,19 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -import enum -from typing import Type, Tuple -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Union -import warnings +from typing import Optional -import cutlass import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, Int32, if_generate - from cutlass.pipeline import ( - Agent, CooperativeGroup, - PipelineOp, - SyncObject, MbarrierArray, - TmaStoreFence, - PipelineUserType, + PipelineOp, PipelineState, + PipelineUserType, + SyncObject, + TmaStoreFence, make_pipeline_state, pipeline_init_wait, ) @@ -273,7 +266,6 @@ class PipelineAsync: return self.make_producer(), self.make_consumer() - @dataclass(frozen=True) class PipelineCpAsync(PipelineAsync): """ @@ -338,7 +330,9 @@ class PipelineTmaAsync(PipelineAsync): @staticmethod @cute.jit - def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32): + def init_empty_barrier_arrive_signal( + cta_layout_vmnk: cute.Layout, tidx: Int32, mcast_mode_mn: tuple[int, int] + ): """ Initialize the empty barrier arrive signal This function returns the destination cta rank and a boolean indicating if the signalling thread is the same as the current thread @@ -357,21 +351,27 @@ class PipelineTmaAsync(PipelineAsync): dst_cta_coord = cta_layout_vmnk.get_hier_coord(dst_rank) cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster) - is_same_row = ( + is_mcast_mode_m = ( dst_cta_coord[0] == cur_cta_coord[0] and dst_cta_coord[1] == cur_cta_coord[1] and dst_cta_coord[3] == cur_cta_coord[3] ) - is_same_col = ( + is_mcast_mode_n = ( dst_cta_coord[0] == cur_cta_coord[0] and dst_cta_coord[2] == cur_cta_coord[2] and dst_cta_coord[3] == cur_cta_coord[3] ) - is_same_row_or_col = is_same_row or is_same_col - is_signalling_thread_final = is_signalling_thread and is_same_row_or_col + assert not (mcast_mode_mn[0] == 0 and mcast_mode_mn[1] == 0) + if mcast_mode_mn[0] == 1 and mcast_mode_mn[1] == 0: + is_signalling_thread = is_signalling_thread and is_mcast_mode_m + elif mcast_mode_mn[0] == 0 and mcast_mode_mn[1] == 1: + is_signalling_thread = is_signalling_thread and is_mcast_mode_n + elif mcast_mode_mn[0] == 1 and mcast_mode_mn[1] == 1: + is_mcast_mode_m_or_n = is_mcast_mode_m or is_mcast_mode_n + is_signalling_thread = is_signalling_thread and is_mcast_mode_m_or_n - return dst_rank, is_signalling_thread_final + return dst_rank, is_signalling_thread @staticmethod def create( @@ -383,6 +383,7 @@ class PipelineTmaAsync(PipelineAsync): barrier_storage: cute.Pointer = None, cta_layout_vmnk: Optional[cute.Layout] = None, tidx: Optional[Int32] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), ): """ This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. @@ -400,6 +401,8 @@ class PipelineTmaAsync(PipelineAsync): :type cta_layout_vmnk: cute.Layout | None :param tidx: thread index to consumer async threads :type tidx: Int32 | None + :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. + :type mcast_mode_mn: tuple[int, int] """ if not isinstance(barrier_storage, cute.Pointer): raise ValueError( @@ -425,7 +428,9 @@ class PipelineTmaAsync(PipelineAsync): ( dst_rank, is_signalling_thread, - ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal( + cta_layout_vmnk, tidx, mcast_mode_mn + ) if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: dst_rank = None else: @@ -653,11 +658,107 @@ class PipelineTmaStore(PipelineAsync): self.sync_object_full.tail() +@dataclass(frozen=True) +class PipelineOrder: + """ + PipelineOrder is used for managing ordered pipeline execution with multiple groups. + + This class implements a pipeline ordering mechanism where work is divided into groups + and stages, allowing for controlled progression through pipeline stages with proper + synchronization between different groups. + + The pipeline ordering works as follows: + - The pipeline is divided into 'length' number of groups + - Each group has 'depth' number of stages + - Groups execute in a specific order with synchronization barriers + - Each group waits for the previous group to complete before proceeding + + **Example:** + + .. code-block:: python + + # Create pipeline order with 3 groups, each with 2 stages + pipeline_order = PipelineOrder.create( + barrier_storage=smem_ptr, # shared memory pointer for barriers + depth=2, # 2 stages per group + length=3, # 3 groups total + group_id=0, # current group ID (0, 1, or 2) + producer_group=producer_warp # cooperative group for producers + ) + + # In the pipeline loop + for stage in range(num_stages): + pipeline_order.wait() # Wait for previous group to complete + # Process current stage + pipeline_order.arrive() # Signal completion to next group + """ + + sync_object_full: SyncObject + depth: int + length: int + group_id: int + state: PipelineState + + @staticmethod + def create( + barrier_storage: cute.Pointer, + depth: int, + length: int, + group_id: int, + producer_group: CooperativeGroup, + ): + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + + num_stages = depth * length + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + + pipeline_init_wait() + + return PipelineOrder( + sync_object_full, + depth, + length, + group_id, + PipelineState( + depth, + Int32(0), + Int32(0), + Int32(group_id == 0), # phase + ), + ) + + def get_barrier_for_current_stage_idx(self, group_id): + return self.state.index * self.length + group_id + + def arrive(self): + signalling_id = (self.group_id + 1) % self.length + idx = self.get_barrier_for_current_stage_idx(signalling_id) + cute.arch.mbarrier_arrive(self.sync_object_full.get_barrier(idx)) + self.state.advance() + + def wait(self): + idx = self.get_barrier_for_current_stage_idx(self.group_id) + cute.arch.mbarrier_wait( + self.sync_object_full.get_barrier(idx), self.state.phase + ) + + ################################################################# # Utilities to help user of pipeline to simplify the workflow ################################################################# +@dataclass(frozen=True) class ImmutableResourceHandle: __origin: PipelineAsync __immutable_state: PipelineState @@ -703,10 +804,11 @@ class ImmutableResourceHandle: self.__origin, self.__immutable_state.__new_from_mlir_values__(values) ) + class PipelineProducer: """A class representing a producer in an asynchronous pipeline. - The Producer class manages the producer side of an asynchronous pipeline, handling + This class manages the producer side of an asynchronous pipeline, handling synchronization and state management for producing data. It provides methods for acquiring, committing, and advancing through pipeline stages. @@ -719,22 +821,31 @@ class PipelineProducer: **Examples:** - .. code-block:: python + .. code-block:: python - pipeline = PipelineAsync.create(...) - producer = pipeline.create_producer(producer_group, stages) - for i in range(iterations): - handle = producer.acquire_and_advance() # Wait for buffer to be empty - # Produce data - producer.commit(handle) # Signal data is ready - # An alternative way to do this is: - # handle.commit() # Signal data is ready + pipeline = PipelineAsync.create(...) + producer, consumer = pipeline.make_participants() + for i in range(iterations): + # Try to acquire the current buffer without blocking + try_acquire_token = producer.try_acquire() + + # Do something else independently + ... + + # Wait for current buffer to be empty & Move index to next stage + # If try_acquire_token is True, return immediately + # If try_acquire_token is False, block until buffer is empty + handle = producer.acquire_and_advance(try_acquire_token) + + # Produce data + handle.commit() """ __pipeline: PipelineAsync __state: PipelineState __group: CooperativeGroup + @dataclass(frozen=True) class ImmutableResourceHandle(ImmutableResourceHandle): @property def barrier(self): @@ -749,6 +860,7 @@ class PipelineProducer: def commit(self): """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. """ self.get_origin().producer_commit( @@ -769,6 +881,10 @@ class PipelineProducer: self.__state = state self.__group = group + def reset(self): + """Reset the count of how many handles this producer has committed.""" + self.__state.reset_count() + def acquire( self, try_acquire_token: Optional[Boolean] = None, @@ -794,13 +910,18 @@ class PipelineProducer: def acquire_and_advance( self, try_acquire_token: Optional[Boolean] = None ) -> ImmutableResourceHandle: - """Wait for the current buffer to be empty before producing data. - Then advance to the next stage. - This is a blocking operation. + """Acquire the current buffer and advance to the next pipeline stage. - :param try_acquire_token: Optional token to try to acquire the buffer + This method combines the acquire() and advance() operations into a single call. + It first waits for the current buffer to be empty before producing data, + then advances the pipeline to the next stage. + + :param try_acquire_token: Token indicating whether to try non-blocking acquire. + If True, returns immediately without waiting. If False or None, blocks + until buffer is empty. :type try_acquire_token: Optional[Boolean] - :return: A handle to the producer for committing the data + :return: A handle to the producer that can be used to commit data to the + acquired buffer stage :rtype: ImmutableResourceHandle """ handle = self.acquire(try_acquire_token) @@ -808,27 +929,37 @@ class PipelineProducer: return handle def try_acquire(self) -> Boolean: - """Try to acquire the current buffer without blocking. + """Attempt to acquire the current buffer without blocking. - :return: True if acquisition was successful, False otherwise + This method tries to acquire the current buffer stage for producing data + without waiting. It can be used to check buffer availability before + committing to a blocking acquire operation. + + :return: A boolean token indicating whether the buffer was successfully acquired :rtype: Boolean """ return self.__pipeline.producer_try_acquire(self.__state) def commit(self, handle: Optional[ImmutableResourceHandle] = None): """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + + :param handle: Optional handle to commit, defaults to None + :type handle: Optional[ImmutableResourceHandle] + :raises AssertionError: If provided handle does not belong to this producer """ if handle is not None: - assert ( - handle.get_origin() is self - ), "ResourceHandle does not belong to this PipelineProducer instance" + assert handle.get_origin() is self, ( + "ResourceHandle does not belong to this PipelineProducer instance" + ) handle.commit() else: self.__pipeline.producer_commit(self.__state) def tail(self): """Ensure all used buffers are properly synchronized before producer exit. + This should be called before the producer finishes to avoid dangling signals. """ self.__pipeline.producer_tail(self.__state) @@ -854,6 +985,7 @@ class PipelineProducer: self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group ) + class PipelineConsumer: """A class representing a consumer in an asynchronous pipeline. @@ -869,22 +1001,35 @@ class PipelineConsumer: :type __group: CooperativeGroup **Examples:** - .. code-block:: python - pipeline = PipelineAsync.create(...) - consumer = pipeline.create_consumer(consumer_group, stages) - for i in range(iterations): - handle = consumer.wait_and_advance() # Wait for data to be ready - # Consume data - consumer.release(handle) # Signal buffer is empty - # An alternative way to do this is: - # handle.release() # Signal buffer is empty + .. code-block:: python + + pipeline = PipelineAsync.create(...) + producer, consumer = pipeline.make_participants() + for i in range(iterations): + # Try to wait for buffer to be full + try_wait_token = consumer.try_wait() + + # Do something else independently + ... + + # Wait for buffer to be full & Move index to next stage + # If try_wait_token is True, return immediately + # If try_wait_token is False, block until buffer is full + handle = consumer.wait_and_advance(try_wait_token) + + # Consume data + handle.release( ) # Signal buffer is empty + + # Alternative way to do this is: + # handle.release() # Signal buffer is empty """ __pipeline: PipelineAsync __state: PipelineState __group: CooperativeGroup + @dataclass(frozen=True) class ImmutableResourceHandle(ImmutableResourceHandle): def release(self): """Signal that data production is complete for the current stage. @@ -908,14 +1053,20 @@ class PipelineConsumer: self.__group = group self.__state = state - def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle: - """Wait for data to be ready in the current buffer. - This is a blocking operation. + def reset(self): + """Reset the count of how many handles this consumer has consumed.""" + self.__state.reset_count() - :param try_wait_token: Optional token to try to wait for the buffer + def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle: + """Wait for data to be ready in the current buffer. This is a blocking operation + that will not return until data is available. + + :param try_wait_token: Token used to attempt a non-blocking wait for the buffer. + If provided and True, returns immediately if buffer is not ready. :type try_wait_token: Optional[Boolean] - :return: A handle to the consumer for releasing the data - :rtype: PipelineConsumerHandle + :return: An immutable handle to the consumer that can be used to release the buffer + once data consumption is complete + :rtype: ImmutableResourceHandle """ self.__pipeline.consumer_wait(self.__state, try_wait_token) handle = PipelineConsumer.ImmutableResourceHandle( @@ -924,29 +1075,40 @@ class PipelineConsumer: return handle def advance(self): - """Move to the next pipeline stage.""" + """Advance the consumer to the next pipeline stage. + + This updates the internal state to point to the next buffer in the pipeline. + Should be called after consuming data from the current buffer. + """ self.__state.advance() def wait_and_advance( self, try_wait_token: Optional[Boolean] = None ) -> ImmutableResourceHandle: - """Wait for data to be ready in the current buffer. - Then advance to the next stage. - This is a blocking operation. + """Atomically wait for data and advance to next pipeline stage. - :param try_wait_token: Optional token to try to wait for the buffer + This is a convenience method that combines wait() and advance() into a single + atomic operation. It will block until data is available in the current buffer, + then automatically advance to the next stage. + + :param try_wait_token: Token used to attempt a non-blocking wait for the buffer. + If provided and True, returns immediately if buffer is not ready. :type try_wait_token: Optional[Boolean] - :return: A handle to the consumer for releasing the data - :rtype: PipelineConsumerHandle + :return: An immutable handle to the consumer that can be used to release the buffer + once data consumption is complete + :rtype: ImmutableResourceHandle """ handle = self.wait(try_wait_token) self.advance() return handle def try_wait(self) -> Boolean: - """Try to check if data is ready without blocking. + """Non-blocking check if data is ready in the current buffer. - :return: True if data is ready, False otherwise + This method provides a way to test if data is available without blocking. + Unlike wait(), this will return immediately regardless of buffer state. + + :return: True if data is ready to be consumed, False if the buffer is not yet ready :rtype: Boolean """ return self.__pipeline.consumer_try_wait(self.__state) @@ -956,9 +1118,9 @@ class PipelineConsumer: This allows producers to start producing new data. """ if handle is not None: - assert ( - handle.get_origin() is self - ), "ResourceHandle does not belong to this PipelineConsumer instance" + assert handle.get_origin() is self, ( + "ResourceHandle does not belong to this PipelineConsumer instance" + ) handle.release() else: self.__pipeline.consumer_release(self.__state) diff --git a/python/CuTeDSL/cutlass/torch.py b/python/CuTeDSL/cutlass/torch.py index e5ee5777..09c30609 100644 --- a/python/CuTeDSL/cutlass/torch.py +++ b/python/CuTeDSL/cutlass/torch.py @@ -18,14 +18,13 @@ from typing import Optional, Type, Union from cutlass.cute.typing import ( Numeric, Boolean, - Float, - Integer, TFloat32, Float8E4M3B11FNUZ, Float8E4M3FN, Float8E5M2, Float8E8M0FNU, Float4E2M1FN, + Int4, Tensor, ) from cutlass.cute.runtime import from_dlpack @@ -169,7 +168,7 @@ def convert_cute_tensor( ) -> Tensor: """ Change the value of the cute tensor to make its value converted from a fp32 torch tensor. - Used for fp8 types tensor creatation now. + Used for fp8 and int4 types tensor creatation now. """ # if torch_tensor is on cpu, create a gpu copy if f32_torch_tensor.device.type == "cpu": @@ -177,6 +176,7 @@ def convert_cute_tensor( # Fp8 type need explicit type conversion if dtype in { + Int4, Float8E5M2, Float8E4M3FN, Float8E8M0FNU, @@ -281,7 +281,9 @@ def cute_tensor_like( """ # allocate device buffer for cute tensor - if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + if (cutlass_dtype.is_float and cutlass_dtype.width <= 8) or ( + cutlass_dtype.is_integer and cutlass_dtype.width == 4 + ): torch_dtype = torch.int8 else: torch_dtype = dtype(cutlass_dtype) @@ -298,7 +300,9 @@ def cute_tensor_like( cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) # initialize the cute tensor data - if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + if (cutlass_dtype.is_float and cutlass_dtype.width <= 8) or ( + cutlass_dtype.is_integer and cutlass_dtype.width == 4 + ): cute_tensor = convert_cute_tensor( data_ref.to(dtype=torch.float32), cute_tensor, diff --git a/python/CuTeDSL/cutlass/utils/__init__.py b/python/CuTeDSL/cutlass/utils/__init__.py index aec0a186..024a4096 100644 --- a/python/CuTeDSL/cutlass/utils/__init__.py +++ b/python/CuTeDSL/cutlass/utils/__init__.py @@ -33,6 +33,10 @@ from .blackwell_helpers import ( from .hopper_helpers import ( sm90_get_smem_store_op, + make_smem_layout_a as sm90_make_smem_layout_a, + make_smem_layout_b as sm90_make_smem_layout_b, + make_smem_layout_epi as sm90_make_smem_layout_epi, + compute_tile_shape_or_override, ) from .blockscaled_layout import ( @@ -56,14 +60,11 @@ from .tensormap_manager import ( TensorMapManager, ) -from .smem_allocator import SmemAllocator +from .smem_allocator import SmemAllocator, get_smem_capacity_in_bytes +from .tmem_allocator import TmemAllocator from .layout import LayoutEnum -from .smem_capacity import ( - get_smem_capacity_in_bytes, -) - from .distributed_helpers import ( spin_lock_wait, spin_lock_multimem_arrive, @@ -76,9 +77,14 @@ from .distributed_helpers import ( sm_wise_inter_gpu_multimem_barrier, ) +from . import hopper_helpers as sm90 +from . import blackwell_helpers as sm100 + + __all__ = [ "get_smem_capacity_in_bytes", "SmemAllocator", + "TmemAllocator", "LayoutEnum", "WorkTileInfo", "PersistentTileSchedulerParams", @@ -90,4 +96,15 @@ __all__ = [ "create_initial_search_state", "GroupedGemmTileSchedulerHelper", "HardwareInfo", + "compute_epilogue_tile_shape", + "get_smem_store_op", + "get_tmem_load_op", + "get_num_tmem_alloc_cols", + "make_smem_layout_a", + "make_smem_layout_b", + "make_smem_layout_epi", + "make_trivial_tiled_mma", + "make_blockscaled_trivial_tiled_mma", + "sm90", + "sm100", ] diff --git a/python/CuTeDSL/cutlass/utils/ampere_helpers.py b/python/CuTeDSL/cutlass/utils/ampere_helpers.py deleted file mode 100644 index 1341756f..00000000 --- a/python/CuTeDSL/cutlass/utils/ampere_helpers.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from enum import Enum -from typing_extensions import deprecated -import warnings - - -@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") -class SmemCapacity(Enum): - SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024 - SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 - SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 - - -warnings.warn( - "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", - DeprecationWarning, - stacklevel=2, -) -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value, - "sm86": SmemCapacity.SM86_SMEM_CAPACITY_BYTES.value, - "sm89": SmemCapacity.SM89_SMEM_CAPACITY_BYTES.value, -} diff --git a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py index 6fb6bf4d..13cac5c2 100644 --- a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py +++ b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -9,11 +9,8 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from enum import Enum from math import log2, ceil from typing import List, Type, Union, Tuple -from typing_extensions import deprecated -import warnings from cutlass.cutlass_dsl import ( Float16, @@ -40,7 +37,7 @@ from cutlass.cute.nvgpu.tcgen05 import ( MmaMXF8Op, MmaMXF4Op, MmaMXF4NVF4Op, - OperandSource, + OperandSource as Tcgen05OperandSource, OperandMajorMode, CtaGroup, Ld16x64bOp, @@ -63,23 +60,8 @@ from cutlass.cute.nvgpu.cpasync import ( ) from cutlass.utils.layout import LayoutEnum - -@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") -class SmemCapacity(Enum): - SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 - SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 - - -warnings.warn( - "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", - DeprecationWarning, - stacklevel=2, -) -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value, - "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value, -} +# Type alias for documentation clarity +OperandSource = Tcgen05OperandSource @dsl_user_op @@ -368,8 +350,8 @@ def get_tmem_load_op( :rtype: cute.CopyAtom :raises ValueError: If the function cannot handle the given combination of accumulation - and dimension types, or if it cannot determine the appropriate configuration based on - the input parameters. + and dimension types, or if it cannot determine the appropriate configuration based on + the input parameters. """ is_m_major = layout_d.is_m_major_c() @@ -557,9 +539,9 @@ def get_num_tmem_alloc_cols( # Sum up the num_tmem_alloc_cols_per_tensor num_tmem_alloc_cols = sum(num_tmem_alloc_cols_per_tensor) - # Round up num_tmem_cols_total to the nearest power of 2 + # Round up num_tmem_cols_total to the nearest power of 2 and make sure it is at least 32 if rounding: - num_tmem_alloc_cols = 1 << ceil(log2(num_tmem_alloc_cols)) + num_tmem_alloc_cols = max(1 << ceil(log2(num_tmem_alloc_cols)), 32) # Validate the number of TMEM allocation columns SM100_TMEM_CAPACITY_COLUMNS = 512 @@ -627,6 +609,54 @@ def get_smem_layout_atom_ab( return SmemLayoutAtomKind.K_INTER +@dsl_user_op +def make_smem_layout( + leading_mode: OperandMajorMode, + smem_tile_shape: cute.Tile, + a_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """Construct a staged SMEM layout for an operand given its major mode and tile shape. + + This helper: + + 1. Selects a SMEM layout atom using simple heuristics based on the operand's major mode, + element type, and the size of the major dimension in ``smem_tile_shape``. + 2. Tiles the atom to ``smem_tile_shape`` and appends a staging dimension of length ``num_stages``. + 3. Orders the ``(M, N, stage)`` axes so the major dimension is contiguous, then coalesces. + + :param leading_mode: Operand major mode (``MN`` or ``K``) of the staged operand. + :type leading_mode: cute.nvgpu.tcgen05.OperandMajorMode + :param smem_tile_shape: 2D SMEM tile shape to stage (before the staging dimension is appended). + :type smem_tile_shape: cute.Tile + :param a_dtype: Element type of the staged operand. + :type a_dtype: Type[Numeric] + :param num_stages: Number of pipeline stages (depth of the staging dimension). + :type num_stages: int + + :return: Staged SMEM layout for the operand. + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + smem_layout_atom_kind = get_smem_layout_atom_ab( + leading_mode, a_dtype, smem_tile_shape, loc=loc, ip=ip + ) + smem_layout_atom = make_smem_layout_atom( + smem_layout_atom_kind, a_dtype, loc=loc, ip=ip + ) + + is_k_major = leading_mode == OperandMajorMode.K + smem_layout = cute.tile_to_shape( + smem_layout_atom, + cute.append(smem_tile_shape, num_stages), + order=(0, 1, 2) if is_k_major else (1, 0, 2), + ) + return cute.coalesce(smem_layout, target_profile=(1, 1, 1), loc=loc, ip=ip) + + @dsl_user_op def make_smem_layout_a( tiled_mma: cute.TiledMma, @@ -638,6 +668,7 @@ def make_smem_layout_a( ip=None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps with: + 1. Get the partitioned shape of the A tensor based on the tiled_mma & MMA tiler. 2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size. 3. cute.Tile the SMEM layout atom to the MMA tile shape. @@ -664,26 +695,18 @@ def make_smem_layout_a( cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], ) + smem_layout_atom_kind = get_smem_layout_atom_ab( + tiled_mma.op.a_major_mode, a_dtype, a_smem_shape_mn_k, loc=loc, ip=ip + ) a_smem_layout_atom = make_smem_layout_atom( - get_smem_layout_atom_ab( - tiled_mma.op.a_major_mode, - a_dtype, - a_smem_shape_mn_k, - loc=loc, - ip=ip, - ), - a_dtype, - loc=loc, - ip=ip, + smem_layout_atom_kind, a_dtype, loc=loc, ip=ip ) - a_smem_layout_staged = tile_to_mma_shape( - a_smem_layout_atom, - cute.append(a_smem_shape, num_stages, loc=loc, ip=ip), - order=((1, 0, 2) if not is_k_major else (0, 1, 2)), - loc=loc, - ip=ip, + + a_smem_shape = cute.append(a_smem_shape, num_stages, loc=loc, ip=ip) + order = (2, 1, 3) if not is_k_major else (1, 2, 3) + return tile_to_mma_shape( + a_smem_layout_atom, a_smem_shape, order=order, loc=loc, ip=ip ) - return a_smem_layout_staged @dsl_user_op @@ -697,6 +720,7 @@ def make_smem_layout_b( ip=None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps: + 1. Get the partitioned shape of the B tensor based on the tiled_mma & MMA tiler. 2. Select the heuristic SMEM layout atom based on the B tensor's majorness, the data type, and the major mode size. 3. cute.Tile the SMEM layout atom to the MMA tile shape. @@ -723,27 +747,19 @@ def make_smem_layout_b( cute.size(b_smem_shape[0][0], loc=loc, ip=ip) * b_smem_shape[1], cute.size(b_smem_shape[0][1], loc=loc, ip=ip) * b_smem_shape[2], ) - b_smem_layout_atom = make_smem_layout_atom( - get_smem_layout_atom_ab( - tiled_mma.op.b_major_mode, - b_dtype, - b_smem_shape_nk, - loc=loc, - ip=ip, - ), - b_dtype, - loc=loc, - ip=ip, + + smem_layout_atom_kind = get_smem_layout_atom_ab( + tiled_mma.op.b_major_mode, b_dtype, b_smem_shape_nk, loc=loc, ip=ip ) - b_smem_layout_staged = tile_to_mma_shape( - b_smem_layout_atom, - cute.append(b_smem_shape, num_stages, loc=loc, ip=ip), - order=((1, 0, 2) if not is_k_major else (0, 1, 2)), - loc=loc, - ip=ip, + b_smem_layout_atom = make_smem_layout_atom( + smem_layout_atom_kind, b_dtype, loc=loc, ip=ip ) - return b_smem_layout_staged + b_smem_shape = cute.append(b_smem_shape, num_stages, loc=loc, ip=ip) + order = (2, 1, 3) if not is_k_major else (1, 2, 3) + return tile_to_mma_shape( + b_smem_layout_atom, b_smem_shape, order=order, loc=loc, ip=ip + ) @dsl_user_op @@ -801,6 +817,7 @@ def make_smem_layout_epi( ip=None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps: + 1. Select the heuristic SMEM layout atom based on the epilog tile shape, the epilog tensor's majorness, and the element type. 2. cute.Tile the SMEM layout atom to the epilog tile shape. @@ -823,21 +840,17 @@ def make_smem_layout_epi( cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip ) - c_smem_layout_atom = make_smem_layout_atom( - get_smem_layout_atom_epi( - epi_layout, - epi_dtype, - epi_tile, - loc=loc, - ip=ip, - ), - epi_dtype, - loc=loc, - ip=ip, + smem_atom_kind = get_smem_layout_atom_epi( + epi_layout, epi_dtype, epi_tile, loc=loc, ip=ip ) + c_smem_layout_atom = make_smem_layout_atom( + smem_atom_kind, epi_dtype, loc=loc, ip=ip + ) + + epilog_shape = cute.append(epilog_shape, epi_stage, loc=loc, ip=ip) epi_smem_layout_staged = cute.tile_to_shape( c_smem_layout_atom, - cute.append(epilog_shape, epi_stage, loc=loc, ip=ip), + epilog_shape, order=((1, 0, 2) if not epi_layout.is_n_major_c() else (0, 1, 2)), loc=loc, ip=ip, @@ -875,7 +888,7 @@ def make_trivial_tiled_mma( :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. :type mma_tiler_mn: Tuple[int, int] :param a_source: The source of operand A (SMEM by default or TMEM). - :type a_source: OperandSource + :type a_source: cutlass.cute.nvgpu.tcgen05.OperandSource :return: A tiled MMA atom. :rtype: cute.TiledMma @@ -926,7 +939,7 @@ def make_trivial_tiled_mma( else: raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") - return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + return cute.make_tiled_mma(cute.make_mma_atom(mma_op), loc=loc, ip=ip) @dsl_user_op @@ -961,7 +974,7 @@ def make_blockscaled_trivial_tiled_mma( :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. :type mma_tiler_mn: Tuple[int, int] :param a_source: The source of operand A (SMEM by default or TMEM). - :type a_source: OperandSource + :type a_source: cutlass.cute.nvgpu.tcgen05.OperandSource :return: A tiled MMA atom. :rtype: cute.TiledMma @@ -996,7 +1009,7 @@ def make_blockscaled_trivial_tiled_mma( else: raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") - return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + return cute.make_tiled_mma(cute.make_mma_atom(mma_op), loc=loc, ip=ip) @dsl_user_op @@ -1133,3 +1146,19 @@ def cluster_shape_to_tma_atom_SFB( raise ValueError( f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" ) + + +__all__ = [ + "compute_epilogue_tile_shape", + "get_smem_store_op", + "get_tmem_load_op", + "get_num_tmem_alloc_cols", + "make_smem_layout_a", + "make_smem_layout_b", + "make_smem_layout_epi", + "make_trivial_tiled_mma", + "make_blockscaled_trivial_tiled_mma", + "cluster_shape_to_tma_atom_A", + "cluster_shape_to_tma_atom_B", + "cluster_shape_to_tma_atom_SFB", +] diff --git a/python/CuTeDSL/cutlass/utils/blockscaled_layout.py b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py index fa1e2eb7..31e64e1b 100644 --- a/python/CuTeDSL/cutlass/utils/blockscaled_layout.py +++ b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py @@ -10,7 +10,6 @@ # is strictly prohibited. from dataclasses import dataclass, field -from typing import Union from cutlass.cutlass_dsl import dsl_user_op @@ -97,6 +96,7 @@ def make_smem_layout_sfa( ) -> cute.Layout: """ Make smem layout for SFA based on: + 1. BlockScaledBasicChunk 2. MMA tiler shape 3. Scale factor vector size @@ -161,6 +161,7 @@ def make_smem_layout_sfb( ) -> cute.Layout: """ Make smem layout for SFB based on: + 1. BlockScaledBasicChunk 2. MMA tiler shape 3. Scale factor vector size @@ -224,6 +225,7 @@ def make_tmem_layout_sfa( ip=None, ) -> cute.Layout: """Make tmem layout for SFA based on: + 1. SFA smem layout per stage 2. Cta tile shape m 3. tiled MMA atom thr size @@ -241,7 +243,7 @@ def make_tmem_layout_sfa( :return: TMEM layout for SFA :rtype: cute.Layout """ - atom_thr_size = cute.size(tiled_mma.thr_id.shape) + atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip) cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa( @@ -261,6 +263,7 @@ def make_tmem_layout_sfb( ip=None, ) -> cute.Layout: """Make tmem layout for SFB based on: + 1. SFB smem layout per stage 2. Cta tile shape m 3. tiled MMA atom thr size @@ -278,7 +281,7 @@ def make_tmem_layout_sfb( :return: TMEM layout for SFB :rtype: cute.Layout """ - atom_thr_size = cute.size(tiled_mma.thr_id.shape) + atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip) cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb( diff --git a/python/CuTeDSL/cutlass/utils/distributed_helpers.py b/python/CuTeDSL/cutlass/utils/distributed_helpers.py index 5853c56c..6e569e0c 100644 --- a/python/CuTeDSL/cutlass/utils/distributed_helpers.py +++ b/python/CuTeDSL/cutlass/utils/distributed_helpers.py @@ -13,16 +13,16 @@ from functools import partial from typing import Tuple import cutlass.cute as cute -from cutlass.cutlass_dsl import T, dsl_user_op, while_generate +from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir import ir -from cutlass._mlir.dialects import arith, llvm, nvvm, scf +from cutlass._mlir.dialects import llvm, nvvm from cutlass._mlir.dialects.nvvm import ( MemOrderKind, MemScopeKind, AtomicOpKind, ) -from cutlass.cute.typing import Pointer, Int32, Boolean +from cutlass.cute.typing import Pointer, Int32 @dsl_user_op @@ -41,32 +41,42 @@ def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32: @cute.jit def ld_bypass(input_tensor: cute.Tensor): - fragment = cute.make_fragment(input_tensor.layout, input_tensor.element_type) + fragment = cute.make_rmem_tensor(input_tensor.layout, input_tensor.element_type) copy_atom_load = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), input_tensor.element_type, memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, memory_scope=cute.nvgpu.common.MemoryScope.SYS, ) - cute.copy(copy_atom_load, input_tensor, fragment) + cute.copy_atom_call(copy_atom_load, input_tensor, fragment) vals = fragment.load() return vals + @cute.jit -def spin_lock_wait(lock_ptr: Pointer, expect_count: Int32, mem_order : str = "relaxed", mem_scope : str = "gpu", loc=None, ip=None) -> None: +def spin_lock_wait( + lock_ptr: Pointer, + expect_count: Int32, + mem_order: str = "relaxed", + mem_scope: str = "gpu", + loc=None, + ip=None, +) -> None: """ wait on a spin lock until the expected count is reached. """ res = 0 while res != expect_count: res = nvvm.atomicrmw( - T.i32(), - AtomicOpKind.CAS, - lock_ptr.llvm_ptr, + T.i32(), + AtomicOpKind.CAS, + lock_ptr.llvm_ptr, Int32(0).ir_value(loc=loc, ip=ip), b=Int32(expect_count).ir_value(loc=loc, ip=ip), - mem_order=MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED, - syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS + mem_order=( + MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED + ), + syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS, ) @@ -77,7 +87,7 @@ def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None: """ llvm.inline_asm( None, - [mc_ptr.toint().ir_value()], + [mc_ptr.toint().ir_value(loc=loc, ip=ip)], "multimem.red.release.sys.global.add.u32 [$0], 1;", "l", has_side_effects=True, @@ -86,6 +96,7 @@ def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None: ip=ip, ) + @dsl_user_op def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None: """ @@ -93,7 +104,7 @@ def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None: """ llvm.inline_asm( None, - [mc_ptr.toint().ir_value()], + [mc_ptr.toint().ir_value(loc=loc, ip=ip)], "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;", "l", has_side_effects=True, @@ -110,7 +121,9 @@ def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip) -def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, num_ranks, loc=None, ip=None) -> None : +def sm_wise_inter_gpu_multimem_barrier( + barrier: Pointer, barrier_mc: Pointer, num_ranks, loc=None, ip=None +) -> None: """ barrier for inter-gpu sm-wise """ @@ -119,7 +132,9 @@ def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, pid = bidx + bidy * bdimx + bidz * bdimx * bdimy multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip) cute.arch.fence_proxy(cute.arch.ProxyKind.alias) - spin_lock_wait(barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip) + spin_lock_wait( + barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip + ) @dsl_user_op @@ -129,9 +144,9 @@ def multimem_ld_reduce_base( ptx_string: str = "", loc=None, ip=None, -) -> Tuple[Int32, Int32, Int32, Int32]: +) -> Tuple[Int32, Int32, Int32, Int32]: # ld reduce 8xf16 elts - mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) return_struct = llvm.inline_asm( ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"), [mc_ptr_int], @@ -146,11 +161,26 @@ def multimem_ld_reduce_base( return return_regs[0], return_regs[1], return_regs[2], return_regs[3] -multimem_ld_reduce_8xf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_4xf32 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_8xbf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_16xe4m3 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];") -multimem_ld_reduce_16xe5m2 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];") +multimem_ld_reduce_8xf16 = partial( + multimem_ld_reduce_base, + ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];", +) +multimem_ld_reduce_4xf32 = partial( + multimem_ld_reduce_base, + ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];", +) +multimem_ld_reduce_8xbf16 = partial( + multimem_ld_reduce_base, + ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];", +) +multimem_ld_reduce_16xe4m3 = partial( + multimem_ld_reduce_base, + ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];", +) +multimem_ld_reduce_16xe5m2 = partial( + multimem_ld_reduce_base, + ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];", +) @dsl_user_op @@ -165,7 +195,7 @@ def multimem_st_4xb32( ip=None, ) -> None: # st 4x32 bits of data - mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) llvm.inline_asm( T.i32(), [mc_ptr_int, x, y, z, w], @@ -176,4 +206,3 @@ def multimem_st_4xb32( loc=loc, ip=ip, ) - diff --git a/python/CuTeDSL/cutlass/utils/gemm_helper.py b/python/CuTeDSL/cutlass/utils/gemm_helper.py new file mode 100644 index 00000000..e69de29b diff --git a/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py b/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py index a51bae62..b990e960 100644 --- a/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py +++ b/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py @@ -12,7 +12,12 @@ from typing import List, Tuple import cutlass.cute as cute -from cutlass.cutlass_dsl import Int32, extract_mlir_values, new_from_mlir_values +from cutlass.cutlass_dsl import ( + Int32, + extract_mlir_values, + new_from_mlir_values, + const_expr, +) from cutlass._mlir import ir from cutlass.utils.static_persistent_tile_scheduler import PersistentTileSchedulerParams @@ -270,7 +275,7 @@ class GroupedGemmTileSchedulerHelper: clamp_value = 0 idx = 1 sum_per_thread = value_per_thread - while idx < cute.arch.WARP_SIZE: + while const_expr(idx < cute.arch.WARP_SIZE): value = cute.arch.shuffle_sync_up( sum_per_thread, idx, mask_and_clamp=clamp_value ) @@ -292,7 +297,7 @@ class GroupedGemmTileSchedulerHelper: :return: The problem shape tensor for the specified group :rtype: cute.Tensor """ - cur_problem_mnkl = cute.make_fragment( + cur_problem_mnkl = cute.make_rmem_tensor( cute.make_layout(4), problem_shape_mnkl.element_type ) cute.autovec_copy(problem_shape_mnkl[(group_idx, None)], cur_problem_mnkl) diff --git a/python/CuTeDSL/cutlass/utils/hardware_info.py b/python/CuTeDSL/cutlass/utils/hardware_info.py index e86fcbef..2329b76d 100644 --- a/python/CuTeDSL/cutlass/utils/hardware_info.py +++ b/python/CuTeDSL/cutlass/utils/hardware_info.py @@ -42,7 +42,6 @@ class HardwareInfo: # Getting the max active clusters for a given cluster size def get_max_active_clusters(self, cluster_size: int) -> int: - self._get_device_function() if self._cuda_driver_version_lt(11, 8): raise RuntimeError( "CUDA Driver version < 11.8, cannot get _max_active_clusters" @@ -52,6 +51,8 @@ class HardwareInfo: f"Cluster size must be between 1 and 32, {cluster_size} is not supported" ) + device_fn = self._get_device_function(self.device) + max_shared_memory_per_block = self._checkCudaErrors( driver.cuDeviceGetAttribute( driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, @@ -67,12 +68,16 @@ class HardwareInfo: ) max_dynamic_shared_memory = self._checkCudaErrors( driver.cuOccupancyAvailableDynamicSMemPerBlock( - self.kernel, 1, 1 # numBlocks # blockSize + self.kernel, + 1, + 1, # numBlocks # blockSize ) ) max_active_blocks = self._checkCudaErrors( driver.cuOccupancyMaxActiveBlocksPerMultiprocessor( - self.kernel, 1, max_dynamic_shared_memory # blockSize, + self.kernel, + 1, + max_dynamic_shared_memory, # blockSize, ) ) # allow non-portable cluster size to support detection of non-portable cluster size @@ -168,7 +173,7 @@ class HardwareInfo: ) # get a empty kernel to compute occupancy - def _get_device_function(self) -> None: - self.compiled_kernel = cute.compile(self._host_function) - self.module = next(iter(self.compiled_kernel.cuda_modules.modules)).cuda_module - self.kernel = next(iter(self.compiled_kernel.cuda_modules.modules)).kernel_ptr + def _get_device_function(self, device) -> None: + self.compiled_kernel = cute.compile(self._host_function).to(device) + self.kernel = self.compiled_kernel.exec_context.kernel_functions[0] + self.module = self.compiled_kernel.exec_context.module.cuda_modules[0] diff --git a/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/python/CuTeDSL/cutlass/utils/hopper_helpers.py index 4cd2bae3..bbf68b4a 100644 --- a/python/CuTeDSL/cutlass/utils/hopper_helpers.py +++ b/python/CuTeDSL/cutlass/utils/hopper_helpers.py @@ -9,10 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Type, Tuple -from enum import Enum -from typing_extensions import deprecated -import warnings +from typing import Type, Union, Tuple, Optional from cutlass.utils.layout import LayoutEnum from cutlass.cutlass_dsl import ( @@ -20,41 +17,31 @@ from cutlass.cutlass_dsl import ( BFloat16, Float8E5M2, Float8E4M3FN, + Int8, + Uint8, Numeric, NumericMeta, dsl_user_op, ) -import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu.common import CopyUniversalOp from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp from cutlass.cute.nvgpu.warpgroup import ( MmaF16BF16Op, MmaF8Op, + MmaI8Op, OperandMajorMode, - OperandSource, + OperandSource as WarpgroupOperandSource, + make_smem_layout_atom, ) - -@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") -class SmemCapacity(Enum): - SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 - - -warnings.warn( - "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", - DeprecationWarning, - stacklevel=2, -) -# Dictionary to map compute capability to SMEM capacity -SMEM_CAPACITY = { - "sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value, -} +# Type alias for documentation clarity +OperandSource = WarpgroupOperandSource @dsl_user_op -def sm90_get_smem_store_op( +def get_smem_store_op( layout_d: LayoutEnum, elem_ty_d: Type[Numeric], elem_ty_acc: Type[Numeric], @@ -98,6 +85,10 @@ def sm90_get_smem_store_op( return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) +# Temporarily wire to existing code to avoid breaking changes +sm90_get_smem_store_op = get_smem_store_op + + def make_trivial_tiled_mma( a_dtype: Type[Numeric], b_dtype: Type[Numeric], @@ -136,9 +127,9 @@ def make_trivial_tiled_mma( """ if a_dtype in {Float16, BFloat16}: - if cutlass.const_expr(a_dtype != b_dtype): + if a_dtype != b_dtype: raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}") - if cutlass.const_expr(a_dtype.width != b_dtype.width): + if a_dtype.width != b_dtype.width: raise TypeError(f"Type width mismatch: {a_dtype.width} != {b_dtype.width}") mma_op = MmaF16BF16Op( @@ -162,11 +153,23 @@ def make_trivial_tiled_mma( a_leading_mode, b_leading_mode, ) + elif a_dtype in {Int8, Uint8} and b_dtype in {Int8, Uint8}: + mma_op = MmaI8Op( + a_dtype, + b_dtype, + acc_dtype, + (*tiler_mn, 32), + a_source, + a_leading_mode, + b_leading_mode, + ) else: raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}") return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk) + +@dsl_user_op def get_smem_layout_atom( layout: LayoutEnum, element_type: Type[Numeric], @@ -180,7 +183,7 @@ def get_smem_layout_atom( :param layout: Layout enum of the tensor :type layout: LayoutEnum :param element_type: Data type of the elements - :type element_type: type[cutlass.Numeric] + :type element_type: type[Numeric] :param major_mode_size: Size of the major mode dimension :type major_mode_size: int @@ -207,3 +210,230 @@ def get_smem_layout_atom( if major_mode_size_bits % sw32_num_contiguous_bits == 0: return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32 return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER + + +@dsl_user_op +def make_smem_layout_a( + a_layout: LayoutEnum, + mma_tiler_mnk: cute.Tile, + a_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps with: + + 1. Get the partitioned shape of the A tensor based on the MMA tiler. + 2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size. + 3. cute.Tile the SMEM layout atom to the MMA tile shape. + 4. Stage the SMEM layout based on the number of stages. + + :param a_layout: The layout enum for tensor A + :type a_layout: LayoutEnum + :param mma_tiler_mnk: The MMA tile shape + :type mma_tiler_mnk: cute.cute.Tile + :param a_dtype: The element type for tensor A + :type a_dtype: Type[Numeric] + :param num_stages: The number of pipeline stages for tensor A + :type num_stages: int + + :return: SMEM layout for tensor A + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + # Extract A tensor shape from the MMA tiler (M dimension) + a_tile_shape_mnk = mma_tiler_mnk + a_smem_shape = cute.slice_(a_tile_shape_mnk, (None, 0, None), loc=loc, ip=ip) + + # Determine if K is the major mode and get the major mode size + is_k_major = a_layout.is_k_major_a() + a_major_mode_size = a_tile_shape_mnk[2] if is_k_major else a_tile_shape_mnk[0] + + # Create SMEM layout atom for A tensor based on major mode and data type + a_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size, loc=loc, ip=ip), + a_dtype, + loc=loc, + ip=ip, + ) + + # Tile the SMEM layout atom to the A tensor shape and add staging dimension + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, num_stages), + order=(0, 1, 2) if is_k_major else (0, 1, 2), + loc=loc, + ip=ip, + ) + + return a_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_b( + b_layout: LayoutEnum, + mma_tiler_mnk: cute.Tile, + b_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps with: + + 1. Get the partitioned shape of the B tensor based on the MMA tiler. + 2. Select the heuristic SMEM layout atom based on the B tensor's majorness, the data type, and the major mode size. + 3. cute.Tile the SMEM layout atom to the MMA tile shape. + 4. Stage the SMEM layout based on the number of stages. + + :param b_layout: The layout enum for tensor B + :type b_layout: LayoutEnum + :param mma_tiler_mnk: The MMA tile shape + :type mma_tiler_mnk: cute.cute.Tile + :param b_dtype: The element type for tensor B + :type b_dtype: Type[Numeric] + :param num_stages: The number of pipeline stages for tensor B + :type num_stages: int + + :return: SMEM layout for tensor B + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + # Extract B tensor shape from the MMA tiler (N and K dimensions) + b_smem_shape = cute.slice_(mma_tiler_mnk, (0, None, None), loc=loc, ip=ip) + + # Determine if K is the major mode and get the major mode size + is_k_major = b_layout.is_k_major_b() + b_major_mode_size = mma_tiler_mnk[2] if is_k_major else mma_tiler_mnk[1] + + # Create SMEM layout atom for B tensor based on major mode and data type + b_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size, loc=loc, ip=ip), + b_dtype, + loc=loc, + ip=ip, + ) + + # Tile the SMEM layout atom to the B tensor shape and add staging dimension + b_smem_layout_staged = cute.tile_to_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, num_stages), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + loc=loc, + ip=ip, + ) + + return b_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_epi( + epi_dtype: Type[Numeric], + epi_layout: LayoutEnum, + epi_tile: cute.Tile, + epi_stage: int, + smem_trg_shape: Optional[cute.Layout] = None, + smem_order: Optional[tuple] = None, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps: + + 1. Select the heuristic SMEM layout atom based on the epilog tile shape, + the epilog tensor's majorness, and the element type. + 2. cute.Tile the SMEM layout atom to the epilog tile shape. + 3. Stage the SMEM layout based on the number of stages. + + :param epi_dtype: The element type for the epilog tensor. + :type epi_dtype: Type[Numeric] + :param epi_layout: The layout enum for the epilog tensor. + :type epi_layout: LayoutEnum + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.cute.Tile + :param epi_stage: The stage of the epilog tensor. + :type epi_stage: int + :param smem_trg_shape: Target shape for SMEM layout (optional). + :type smem_trg_shape: cute.Layout | None + :param smem_order: Order for SMEM layout (optional). + :type smem_order: tuple | None + + :return: SMEM layout for epilog tensors (usually C & D which are processed in the epilog) + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + # Extract output tensor shape from epilog tile + o_smem_shape = epi_tile + + # Determine major mode size based on layout (M or N major) + o_major_mode_size = epi_tile[1] if epi_layout.is_n_major_c() else epi_tile[0] + + # Create SMEM layout atom for output tensor based on layout and data type + o_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom(epi_layout, epi_dtype, o_major_mode_size, loc=loc, ip=ip), + epi_dtype, + loc=loc, + ip=ip, + ) + + # Determine target shape and order for staging (use provided or default) + trg_shape = ( + smem_trg_shape + if smem_trg_shape is not None + else cute.append(o_smem_shape, epi_stage) + ) + + order = ( + smem_order + if smem_order is not None + else (1, 0, 2) + if epi_layout.is_m_major_c() + else (0, 1, 2) + ) + + # Tile the SMEM layout atom to the target shape with staging + o_smem_layout_staged = cute.tile_to_shape( + o_smem_layout_atom, trg_shape, order, loc=loc, ip=ip + ) + + return o_smem_layout_staged + + +def compute_tile_shape_or_override( + tile_shape_mnk: tuple[int, int, int], + element_type: type[Numeric], + is_cooperative: bool = False, + epi_tile_override: Optional[tuple[int, int]] = None, +) -> tuple[int, int]: + """Compute the epilogue tile shape or use override if provided. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param element_type: Data type of elements + :type element_type: type[Numeric] + :param is_cooperative: Whether to use cooperative approach + :type is_cooperative: bool + :param epi_tile_override: Optional override for epilogue tile shape + :type epi_tile_override: Tuple[int, int] or None + + :return: Computed epilogue tile shape + :rtype: Tuple[int, int] + """ + if epi_tile_override is not None: + return epi_tile_override + if is_cooperative: + tile_m = min(128, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(32, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + else: + n_perf = 64 if element_type.width == 8 else 32 + tile_m = min(64, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(n_perf, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + + +__all__ = [ + "get_smem_store_op", + "make_smem_layout_a", + "make_smem_layout_b", + "make_smem_layout_epi", + "compute_tile_shape_or_override", +] diff --git a/python/CuTeDSL/cutlass/utils/layout.py b/python/CuTeDSL/cutlass/utils/layout.py index 4560c266..8f325668 100644 --- a/python/CuTeDSL/cutlass/utils/layout.py +++ b/python/CuTeDSL/cutlass/utils/layout.py @@ -34,6 +34,18 @@ class LayoutEnum(Enum): else warpgroup.OperandMajorMode.MN ) + def is_k_major_a(self): + return self == LayoutEnum.ROW_MAJOR + + def is_m_major_a(self): + return self == LayoutEnum.COL_MAJOR + + def is_n_major_b(self): + return self == LayoutEnum.COL_MAJOR + + def is_k_major_b(self): + return self == LayoutEnum.ROW_MAJOR + def is_n_major_c(self): return self == LayoutEnum.ROW_MAJOR @@ -43,7 +55,14 @@ class LayoutEnum(Enum): @staticmethod def from_tensor(tensor: cute.Tensor) -> "LayoutEnum": ret = None - if tensor.leading_dim == 1: + if isinstance(tensor.leading_dim, tuple): + if tensor.leading_dim[0] == 1: + ret = LayoutEnum.ROW_MAJOR + elif tensor.leading_dim[0] == 0: + ret = LayoutEnum.COL_MAJOR + else: + raise ValueError(f"Invalid leading dimension: {tensor.leading_dim}") + elif tensor.leading_dim == 1: ret = LayoutEnum.ROW_MAJOR elif tensor.leading_dim == 0: ret = LayoutEnum.COL_MAJOR diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index 2500c06e..bd62ed0e 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -9,92 +9,188 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Type, Union, overload - -from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta, CutlassBaseDSL +from typing import Optional, Type, Union, overload import cutlass.cute as cute from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size +from cutlass.cutlass_dsl import CutlassBaseDSL, Int8, Numeric, NumericMeta, dsl_user_op + +SMEM_CAPACITY_MAP = { + "sm_120": (100 - 1) * 1024, + "sm_100": (228 - 1) * 1024, + "sm_90": (228 - 1) * 1024, + "sm_80": (164 - 1) * 1024, + "sm_86": (100 - 1) * 1024, + "sm_89": (100 - 1) * 1024, +} class SmemAllocator: - """A class for managing shared memory allocation on GPU. + """A helper class for managing shared memory allocation on GPU. - This class manages a chunk of shared memory and provides APIs for sub-allocation - inside the chunk. - - :ivar _base: The current base address of the shared memory as an i8 typed dynamic value. - :type _base: cute.Pointer - :ivar _allocated_bytes: The total number of bytes allocated in shared memory. - :type _allocated_bytes: int + This class manages shared memory and provides APIs for allocation of raw bytes, + numeric types, arrays, and tensors with specified layouts and alignments. .. note:: - This class is responsible for managing the allocation of tensors in shared memory. - The base pointer is aligned to 1024 bytes upon initialization. + - The base pointer is aligned to 1024 bytes upon initialization. + - There is no need to explicitly specify shared memory size in kernel launch. + - Currently only supports static layouts. Dynamic layouts are not supported. + + **Examples**: + + .. code-block:: python + + smem = SmemAllocator() + + # Allocate raw bytes + buf_ptr = smem.allocate(100) # 100 bytes + + # Allocate numeric type + int8_ptr = smem.allocate(Int8) # 1 byte + + # Define a struct + @cute.struct + class SharedStorage: + alpha: cutlass.Float32 + x: cutlass.Int32 + + # Allocate struct + struct_ptr = smem.allocate(SharedStorage) # 8 bytes + + # use of struct members + struct_ptr.alpha = 1.0 + struct_ptr.x = 2 + + # Allocate array + int8_array = smem.allocate_array(Int8, 10) # 10 bytes + + # Allocate tensor + layout = cute.make_layout((16, 16)) + tensor = smem.allocate_tensor(Int8, layout) # 256 bytes """ - def __init__(self): - """Initialize the SmemAllocator instance. + @staticmethod + def capacity_in_bytes(compute_capability: str) -> int: + """Get the shared memory capacity in bytes for a given compute capability. - Creates a dynamic shared memory base pointer of type i8, aligned to 1024 bytes. + Returns the maximum shared memory capacity in bytes available for the specified + GPU compute capability. + + :param compute_capability: The compute capability string (e.g. "70", "75", "80") + :type compute_capability: str + :return: The shared memory capacity in bytes + :rtype: int + :raises ValueError: If the compute capability is not supported """ - self._base = get_dyn_smem(Int8, alignment=1024) + if compute_capability not in SMEM_CAPACITY_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return SMEM_CAPACITY_MAP[compute_capability] + + @dsl_user_op + def __init__(self, *, loc=None, ip=None): + """Initialize a new SmemAllocator instance. + + Creates a new shared memory allocator with a base pointer aligned to 1024 bytes. + Tracks the allocator instance for memory management. + + :param loc: Source location information for debugging, defaults to None + :type loc: Optional[ir.Location] + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[ir.InsertionPoint] + """ + self._base = get_dyn_smem(Int8, alignment=1024, loc=loc, ip=ip) self._allocated_bytes = 0 CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) @overload - def allocate(self, size_or_type: int, byte_alignment: int) -> cute.Pointer: ... + def allocate( + self, size_or_type: int, byte_alignment: int, *, loc=None, ip=None + ) -> cute.Pointer: ... @overload def allocate( - self, size_or_type: cute.struct, byte_alignment: int + self, size_or_type: Type[Numeric], byte_alignment: int, *, loc=None, ip=None ) -> cute.Pointer: ... - def allocate(self, size_or_type, byte_alignment: int = 1) -> cute.Pointer: + @overload + def allocate( + self, size_or_type: cute.struct, byte_alignment: int, *, loc=None, ip=None + ) -> cute.Pointer: ... + + @dsl_user_op + def allocate( + self, size_or_type, byte_alignment: int = 1, *, loc=None, ip=None + ) -> cute.Pointer: """Allocate a block of memory with specified size and alignment. - This method adjusts the base pointer to ensure proper alignment and updates - the internal state to track allocated memory. + This method allocates a block of shared memory with the specified size and alignment requirements. + It supports allocating raw bytes, numeric types(as scalar value), and struct types. - :param size_or_type: The number of bytes to allocate or a struct class - :type size_or_type: Union[int, cute.struct] - :param byte_alignment: The byte alignment requirement, defaults to 1 (no alignment) + :param size_or_type: The allocation specification, which can be: + - An integer specifying the number of bytes to allocate + - A Numeric type (e.g., Int8, Float32) to allocate space for one element + - A struct type to allocate space for the entire struct + :type size_or_type: Union[int, Type[Numeric], cute.struct] + :param byte_alignment: The minimum byte alignment requirement for the allocation, defaults to 1 :type byte_alignment: int, optional - :return: Pointer to the start of the allocated memory block or struct instance + :param loc: Source location information for debugging, defaults to None + :type loc: Optional[ir.Location] + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[ir.InsertionPoint] + :return: For raw bytes and numeric types, returns a pointer to the allocated memory. + For struct types, returns an initialized struct instance at the allocated location. :rtype: cute.Pointer :raises ValueError: If size is negative or alignment is less than 1 + :raises TypeError: If size_or_type is not an integer, Numeric type, or struct :raises RuntimeError: If allocation would exceed available shared memory """ - if isinstance(size_or_type, cute.struct): - alignment = max(byte_alignment, size_or_type.__alignof__()) - base_ptr = self.allocate(size_or_type.__sizeof__(), alignment) - return size_or_type(base_ptr) - num_bytes = size_or_type - if num_bytes < 0: - raise ValueError("num_bytes must be non-negative") + if cute.is_integer(size_or_type): + size_in_bytes = size_or_type + elif isinstance(size_or_type, cute.struct): + size_in_bytes = size_or_type.__sizeof__() + alignment = max(byte_alignment, size_or_type.__alignof__()) + base_ptr = self.allocate(size_in_bytes, alignment, loc=loc, ip=ip) + return size_or_type(base_ptr) + elif isinstance(size_or_type, NumericMeta): + size_in_bytes = cute.ceil_div(size_or_type.width, 8) + base_ptr = self.allocate(size_in_bytes, byte_alignment, loc=loc, ip=ip) + return cute.recast_ptr(base_ptr, dtype=size_or_type, loc=loc, ip=ip) + else: + raise TypeError( + f"Expected int, struct, or numeric type, but got {type(size_or_type)}" + ) + + if cute.is_static(size_in_bytes) and size_in_bytes < 0: + raise ValueError("size must be non-negative") + if byte_alignment < 1: - raise ValueError("byte_alignment must be at least 1") + raise ValueError("`byte_alignment` must be at least 1") self._base = self._base.align(byte_alignment) ptr = self._base - self._base += num_bytes + self._base += size_in_bytes if self._allocated_bytes % byte_alignment != 0: self._allocated_bytes += ( byte_alignment - self._allocated_bytes % byte_alignment ) - self._allocated_bytes += num_bytes + self._allocated_bytes += size_in_bytes # Check bounds against available dynamic shared memory cute.testing.assert_( - self._allocated_bytes <= get_dyn_smem_size(), + self._allocated_bytes <= get_dyn_smem_size(loc=loc, ip=ip), f"Allocation failed: shared memory allocation exceeds available memory set in kernel launch. " f"Allocated bytes: {self._allocated_bytes} bytes. " f"Please reduce the allocation or set a larger smem size in kernel launch.", + loc=loc, + ip=ip, ) return ptr - def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1): + @dsl_user_op + def allocate_array( + self, element_type: Type[Numeric], num_elems: int = 1, *, loc=None, ip=None + ): """Allocate an array of elements in shared memory. :param element_type: The type of elements to allocate @@ -114,23 +210,29 @@ class SmemAllocator: ) ptr = self.allocate( - element_type.width // 8 * num_elems, element_type.width // 8 + element_type.width // 8 * num_elems, element_type.width // 8, loc=loc, ip=ip ) - return cute.recast_ptr(ptr, dtype=element_type) + return cute.recast_ptr(ptr, dtype=element_type, loc=loc, ip=ip) + @dsl_user_op def allocate_tensor( self, element_type: Type[Numeric], layout: Union[int, cute.Layout, cute.ComposedLayout], byte_alignment: int = 1, - swizzle: cute.Swizzle = None, + swizzle: Optional[cute.Swizzle] = None, + *, + loc=None, + ip=None, ): """Allocate a tensor in shared memory. + Note: Currently only supports static layouts. Dynamic layouts are not supported. + :param element_type: The type of elements in the tensor :type element_type: Type[Numeric] - :param layout: The layout specification for the tensor + :param layout: The layout specification for the tensor. Must be a static layout. :type layout: Union[int, cute.Layout, cute.ComposedLayout] :param byte_alignment: The byte alignment requirement, defaults to 1 :type byte_alignment: int, optional @@ -152,33 +254,35 @@ class SmemAllocator: and isinstance(layout.inner, cute.Swizzle) ) and (swizzle is not None): raise TypeError( - f"Invalid tensor type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time." + "Invalid tensor type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time." ) if isinstance(layout, int): layout = cute.make_layout(layout) - profile = layout(0) + profile = layout(0, loc=loc, ip=ip) if isinstance(profile, tuple): raise TypeError( - f"cannot allocate a shared memory tensor with a non-integer iterator" + "cannot allocate a shared memory tensor with a non-integer iterator" ) - if not cute.is_static(layout.type): - raise NotImplementedError(f"dynamic layout is not supported: {layout.type}") + if not cute.is_static(layout): + raise NotImplementedError(f"dynamic layout is not supported: {layout}") # At least align the allocation to the natural alignment given by the element type if element_type.width // 8 > byte_alignment: byte_alignment = element_type.width // 8 # Relevant only for sub-byte data types: verify that the entire allocation is byte-aligned - cosize_in_bits = cute.cosize(layout) * element_type.width + cosize_in_bits = cute.cosize(layout, loc=loc, ip=ip) * element_type.width assert isinstance(cosize_in_bits, int) if cosize_in_bits % 8 != 0: raise ValueError("invalid allocation that is not byte-aligned") num_bytes = cosize_in_bits // 8 - ptr = self.allocate(num_bytes, byte_alignment) - ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type) - res = cute.make_tensor(ptr, layout) - return res + ptr = self.allocate(num_bytes, byte_alignment, loc=loc, ip=ip) + ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type, loc=loc, ip=ip) + return cute.make_tensor(ptr, layout, loc=loc, ip=ip) + + +get_smem_capacity_in_bytes = SmemAllocator.capacity_in_bytes diff --git a/python/CuTeDSL/cutlass/utils/smem_capacity.py b/python/CuTeDSL/cutlass/utils/smem_capacity.py deleted file mode 100644 index 87ddb990..00000000 --- a/python/CuTeDSL/cutlass/utils/smem_capacity.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - - -SMEM_CAPACITY_MAP = { - "sm_120": (100 - 1) * 1024, - "sm_100": (228 - 1) * 1024, - "sm_90": (228 - 1) * 1024, - "sm_80": (164 - 1) * 1024, - "sm_86": (100 - 1) * 1024, - "sm_89": (100 - 1) * 1024, -} - - -def get_smem_capacity_in_bytes(compute_capability: str) -> int: - if compute_capability not in SMEM_CAPACITY_MAP: - raise ValueError(f"Unsupported compute capability: {compute_capability}") - return SMEM_CAPACITY_MAP[compute_capability] diff --git a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py index 2873244d..e42f4d2f 100644 --- a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py @@ -90,6 +90,8 @@ class PersistentTileSchedulerParams: self, problem_shape_ntile_mnl: cute.Shape, cluster_shape_mnk: cute.Shape, + swizzle_size: int = 1, + raster_along_m: bool = True, *, loc=None, ip=None, @@ -102,17 +104,26 @@ class PersistentTileSchedulerParams: :type problem_shape_ntile_mnl: cute.Shape :param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions. :type cluster_shape_mnk: cute.Shape + :param swizzle_size: Swizzling size in the unit of cluster. 1 means no swizzle + :type swizzle_size: int + :param raster_along_m: Rasterization order of clusters. Only used when swizzle_size > 1. + True means along M, false means along N. + :type raster_along_m: bool :raises ValueError: If cluster_shape_k is not 1. """ if cluster_shape_mnk[2] != 1: raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") + if swizzle_size < 1: + raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}") self.problem_shape_ntile_mnl = problem_shape_ntile_mnl # cluster_shape_mnk is kept for reconstruction self._cluster_shape_mnk = cluster_shape_mnk self.cluster_shape_mn = cluster_shape_mnk[:2] + self.swizzle_size = swizzle_size + self._raster_along_m = raster_along_m self._loc = loc # By default, we follow m major (col-major) raster order, so make a col-major layout @@ -124,9 +135,51 @@ class PersistentTileSchedulerParams: ip=ip, ) + if swizzle_size > 1: + problem_shape_ncluster_mnl = cute.round_up( + self.problem_layout_ncluster_mnl.shape, + (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1), + ) + + if raster_along_m: + self.problem_layout_ncluster_mnl = cute.make_layout( + ( + problem_shape_ncluster_mnl[0], + (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), + problem_shape_ncluster_mnl[2], + ), + stride=( + swizzle_size, + (1, swizzle_size * problem_shape_ncluster_mnl[0]), + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], + ), + loc=loc, + ip=ip, + ) + else: + self.problem_layout_ncluster_mnl = cute.make_layout( + ( + (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), + problem_shape_ncluster_mnl[1], + problem_shape_ncluster_mnl[2], + ), + stride=( + (1, swizzle_size * problem_shape_ncluster_mnl[1]), + swizzle_size, + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], + ), + loc=loc, + ip=ip, + ) + def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.problem_shape_ntile_mnl, self._cluster_shape_mnk]: + for obj in [ + self.problem_shape_ntile_mnl, + self._cluster_shape_mnk, + self.swizzle_size, + self._raster_along_m, + ]: obj_values = extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -135,7 +188,13 @@ class PersistentTileSchedulerParams: def __new_from_mlir_values__(self, values): obj_list = [] for obj, n_items in zip( - [self.problem_shape_ntile_mnl, self._cluster_shape_mnk], self._values_pos + [ + self.problem_shape_ntile_mnl, + self._cluster_shape_mnk, + self.swizzle_size, + self._raster_along_m, + ], + self._values_pos, ): obj_list.append(new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] @@ -160,7 +219,7 @@ class PersistentTileSchedulerParams: # Total ctas in problem size num_ctas_mnl = tuple( - x * y + cute.size(x) * y for x, y in zip( self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn ) @@ -252,9 +311,8 @@ class StaticPersistentTileScheduler: new_num_tiles_executed, ) - # called by host - @dsl_user_op @staticmethod + @dsl_user_op def create( params: PersistentTileSchedulerParams, block_idx: Tuple[Integer, Integer, Integer], @@ -276,7 +334,6 @@ class StaticPersistentTileScheduler: :return: A StaticPersistentTileScheduler object. :rtype: StaticPersistentTileScheduler """ - params = params # Calculate the number of persistent clusters by dividing the total grid size # by the number of CTAs per cluster @@ -346,9 +403,14 @@ class StaticPersistentTileScheduler: self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip ) - cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord( - current_work_linear_idx, loc=loc, ip=ip - ) + if self.params.swizzle_size == 1: + cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord( + current_work_linear_idx, loc=loc, ip=ip + ) + else: + cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_flat_coord( + current_work_linear_idx, loc=loc, ip=ip + ) # cur_tile_coord is a tuple of i32 values cur_tile_coord = tuple( diff --git a/python/CuTeDSL/cutlass/utils/tensormap_manager.py b/python/CuTeDSL/cutlass/utils/tensormap_manager.py index c6369c20..6db3f8bf 100644 --- a/python/CuTeDSL/cutlass/utils/tensormap_manager.py +++ b/python/CuTeDSL/cutlass/utils/tensormap_manager.py @@ -13,12 +13,12 @@ from dataclasses import dataclass from enum import Enum, auto from typing import Tuple -from cutlass.cutlass_dsl import const_expr - import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass.cutlass_dsl import dsl_user_op import cutlass.cute as cute +from cutlass import const_expr class TensorMapUpdateMode(Enum): @@ -47,10 +47,14 @@ class TensorMapManager: # convert given cute.Pointer or cutlass.Int64 to a cute.Pointer to tensormap. # address_space: the address space of the resulting tensormap pointer. It could be generic or gmem + @dsl_user_op def get_tensormap_ptr( self, ptr: cute.Pointer, address_space=_cute_ir.AddressSpace.gmem, + *, + loc=None, + ip=None, ) -> cute.Pointer: if address_space not in [ _cute_ir.AddressSpace.gmem, @@ -58,17 +62,19 @@ class TensorMapManager: ]: raise ValueError(f"Invalid address space: {address_space} for tensormap") - gmem_ptr_i64 = ptr.toint().ir_value() + gmem_ptr_i64 = ptr.toint().ir_value(loc=loc, ip=ip) gmem_ptr_i64_align_ty = _cute_ir.ConstrainedIntType.get( self.bytes_per_tensormap, gmem_ptr_i64.type.width ) - gmem_ptr_i64_align = _cute_ir.assume(gmem_ptr_i64_align_ty, gmem_ptr_i64) + gmem_ptr_i64_align = _cute_ir.assume( + gmem_ptr_i64_align_ty, gmem_ptr_i64, loc=loc, ip=ip + ) gmem_ptr_ty = _cute_ir.PtrType.get( _cute_nvgpu_ir.TmaDescriptorTiledType.get(), address_space, self.bytes_per_tensormap, ) - return _cute_ir.inttoptr(gmem_ptr_ty, gmem_ptr_i64_align) + return _cute_ir.inttoptr(gmem_ptr_ty, gmem_ptr_i64_align, loc=loc, ip=ip) # init tensormap pointed by dst_ptr with the one inside copy_atom. # dst_ptr should be pointing to a global memory location or a smem location diff --git a/python/CuTeDSL/cutlass/utils/tmem_allocator.py b/python/CuTeDSL/cutlass/utils/tmem_allocator.py new file mode 100644 index 00000000..6b832436 --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/tmem_allocator.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Type + +from cutlass import const_expr +from cutlass.cutlass_dsl import ( + Numeric, + Float32, + extract_mlir_values, + new_from_mlir_values, +) +import cutlass.pipeline as pipeline +import cutlass.cute as cute +from cutlass._mlir import ir + + +class TmemAllocator: + """A class for managing tensor memory allocation on Blackwell GPU. + + This class manages allocation/deallocation of tensor memory, including the mbarrier + synchronization for two cta use case. + + :ivar _alloc_result_dst_smem_ptr: The smem pointer that holds the base address of allocated tensor memory. + :type _alloc_result_dst_smem_ptr: cute.Pointer + :ivar _barrier_for_retrieve: The barrier for retrieving tensor memory ptr. + :type _barrier_for_retrieve: pipeline.NamedBarrier + :ivar _allocator_warp_id: The warp id of the allocator warp. + :type _allocator_warp_id: int + :ivar _is_two_cta: Whether the allocator is for two cta. + :type _is_two_cta: bool + :ivar _num_allocated_columns: The number of columns allocated in the tensor memory. + :type _num_allocated_columns: int + :ivar _two_cta_tmem_dealloc_mbar_ptr: The mbarrier pointer required when deallocating tensor memory for two cta. + :type _two_cta_tmem_dealloc_mbar_ptr: cute.Pointer + """ + + @cute.jit + def _init_dealloc_mbarrier(self): + assert self._two_cta_tmem_dealloc_mbar_ptr is not None, ( + "two_cta_tmem_dealloc_mbar_ptr is required for two cta" + ) + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + _is_allocator_warp = warp_idx == self._allocator_warp_id + if _is_allocator_warp: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + self._two_cta_tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + def __init__( + self, + alloc_result_dst_smem_ptr: cute.Pointer, + barrier_for_retrieve: pipeline.NamedBarrier, + allocator_warp_id: int = 0, + is_two_cta: bool = False, + num_allocated_columns: int = 0, + two_cta_tmem_dealloc_mbar_ptr: Optional[cute.Pointer] = None, + ): + """Initialize the TmemAllocator instance. + + Sets up the allocator state by initializing smem pointer that holds the base address of allocated tensor memory, allocator warp id, whether it is for two cta, number of allocated columns, and barrier for retrieving tensor memory ptr. + Meanwhile, it also initializes the mbarrier pointer for two cta deallocation case. + """ + # TODO: automatically maintain a smem address + self._alloc_result_dst_smem_ptr = alloc_result_dst_smem_ptr + self._allocator_warp_id = allocator_warp_id + self._is_two_cta = is_two_cta + self._num_allocated_columns = num_allocated_columns + self._two_cta_tmem_dealloc_mbar_ptr = two_cta_tmem_dealloc_mbar_ptr + self._barrier_for_retrieve = barrier_for_retrieve + + # Init tmem dealloc mbarrier if two cta + if const_expr(self._is_two_cta): + self._init_dealloc_mbarrier() + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self._alloc_result_dst_smem_ptr) + if self._is_two_cta: + assert self._two_cta_tmem_dealloc_mbar_ptr is not None, ( + "2CTA mode requires the dealloc mbarrier" + ) + values.extend(extract_mlir_values(self._two_cta_tmem_dealloc_mbar_ptr)) + return values + + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "TmemAllocator": + assert len(values) == 2 if self._is_two_cta else 1 + new_alloc_result_dst_smem_ptr = new_from_mlir_values( + self._alloc_result_dst_smem_ptr, [values[0]] + ) + new_two_cta_tmem_dealloc_mbar_ptr = ( + new_from_mlir_values(self._two_cta_tmem_dealloc_mbar_ptr, [values[1]]) + if self._is_two_cta + else None + ) + return TmemAllocator( + new_alloc_result_dst_smem_ptr, + pipeline.NamedBarrier( + barrier_id=self._barrier_for_retrieve.barrier_id, + num_threads=self._barrier_for_retrieve.num_threads, + ), + self._allocator_warp_id, + self._is_two_cta, + self._num_allocated_columns, + new_two_cta_tmem_dealloc_mbar_ptr, + ) + + @cute.jit + def check_valid_num_columns(self, num_columns: int): + """Check if the number of columns is valid. + + This method checks if the number of columns is valid. + It checks if the number of columns is larger than 0, smaller than 512, a multiple of 32, and a power of two. + """ + # larger than 0 + if const_expr(num_columns < 0): + return False + # smaller than 512 + if const_expr(num_columns > 512): + return False + # multiple of 32 + if const_expr(num_columns % 32 != 0): + return False + # power of two + if const_expr(num_columns & (num_columns - 1) != 0): + return False + return True + + @cute.jit + def allocate(self, num_columns: int): + """Allocate a block of tensor memory. + + This method allocates a block of tensor memory from allocator warp and returns a handle to retrieve + the allocated tensor memory address. + """ + + assert self.check_valid_num_columns(num_columns), ( + "num_columns must be multiple of 32 and power of two, and between 0 and 512" + ) + assert self._num_allocated_columns + num_columns <= 512, ( + "total allocated columns must be less than or equal to 512" + ) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + _is_allocator_warp = warp_idx == self._allocator_warp_id + if _is_allocator_warp: + cute.arch.alloc_tmem( + num_columns, + self._alloc_result_dst_smem_ptr, + is_two_cta=self._is_two_cta, + ) + self._num_allocated_columns += num_columns + + @cute.jit + def wait_for_alloc(self): + """Wait for the allocator warp to finish allocation. + + This method is used to synchronize the allocator warp with the other warps before retrieving tmem ptr. + """ + self._barrier_for_retrieve.arrive_and_wait() + + @cute.jit + def retrieve_ptr( + self, + dtype: Type[Numeric] = Float32, + ) -> cute.Pointer: + """Retrieve the pointer to the allocated tensor memory. + + This method can be called by all warps after allocation has been performed + by the allocator warp. + """ + return cute.arch.retrieve_tmem_ptr( + dtype, + alignment=16, + ptr_to_buffer_holding_addr=self._alloc_result_dst_smem_ptr, + ) + + @cute.jit + def relinquish_alloc_permit(self): + """Relinquish the tensor memory allocation permit. + + This method relinquishes the tensor memory allocation permit for the allocator warp, promising + the allocator warp will not allocate any more tensor memory. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + _is_allocator_warp = warp_idx == self._allocator_warp_id + if _is_allocator_warp: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=self._is_two_cta) + + @cute.jit + def free(self, tmem_ptr: cute.Pointer, num_columns: int = 0): + """Deallocate the tensor memory. + + This method sync on mbarrier (for two cta use case) and deallocates the tensor memory from the allocator warp. + User can optionally specify the number of columns to deallocate. If not specified, all allocated columns will be deallocated. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + _is_allocator_warp = warp_idx == self._allocator_warp_id + + assert num_columns <= self._num_allocated_columns, ( + "num_columns must be less than or equal to num_allocated_columns" + ) + if const_expr(num_columns != 0): + assert self.check_valid_num_columns(num_columns), "num_columns is invalid" + + num_deallocate_columns = ( + self._num_allocated_columns if num_columns == 0 else num_columns + ) # if num_columns is 0, deallocate all allocated columns + self._num_allocated_columns -= num_deallocate_columns + if _is_allocator_warp: + if const_expr(self._is_two_cta): + _cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + # Arrive and wait for dealloc signal from peer cta + cute.arch.mbarrier_arrive( + self._two_cta_tmem_dealloc_mbar_ptr, _cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(self._two_cta_tmem_dealloc_mbar_ptr, 0) + # Deallocate tmem + cute.arch.dealloc_tmem( + tmem_ptr, num_deallocate_columns, is_two_cta=self._is_two_cta + ) diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index f588ea75..75eb76cb 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.2.1 +nvidia-cutlass-dsl==4.3.0.dev0 diff --git a/python/cutlass_cppgen/__init__.py b/python/cutlass_cppgen/__init__.py index 9bdd259c..2491da63 100644 --- a/python/cutlass_cppgen/__init__.py +++ b/python/cutlass_cppgen/__init__.py @@ -133,7 +133,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '4.2.1' +this.__version__ = '4.3.0' from cutlass_cppgen.backend import create_memory_pool from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index fbe52eb5..9ab525a8 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -594,7 +594,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode # reduce L1 test runtime if reference kernel is not running on device. if mode == "functional_L1" and profiler_flags_for_verification == "host" : problem_waves = [0.5, 2.5] - if dynamic_cluster: if mode == "functional_L0": diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 0d2449e7..b9c8751d 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -76,6 +76,7 @@ class GemmOperation: GemmKind.GroupedBlockScaledUniversal3x, GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x, + GemmKind.BlockScaledSparseUniversal3x, } self.is_3x = gemm_kind in kinds_3x self.prefix = "3x" if self.is_3x else "" @@ -174,6 +175,7 @@ class GemmOperation: OpcodeClass.WmmaTensorOp, OpcodeClass.SparseTensorOp, OpcodeClass.BlockScaledTensorOp, + OpcodeClass.BlockScaledSparseTensorOp, ] is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops @@ -348,7 +350,7 @@ class GemmOperation: opcode_class_main = self.tile_description.math_instruction.opcode_class instruction_shape = self.tile_description.math_instruction.instruction_shape tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape - if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]: + if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp, OpcodeClass.BlockScaledSparseTensorOp]: tile_shape_m = instruction_shape[0] tile_shape_n = instruction_shape[1] return (tile_shape_m, tile_shape_n, tile_shape_k) @@ -984,7 +986,7 @@ ${compile_guard_end} element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] - if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + if opcode_class_main == OpcodeClass.BlockScaledTensorOp or opcode_class_main == OpcodeClass.BlockScaledSparseTensorOp: grouped = is_grouped(operation.gemm_kind) if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): epi_tile_mn = "cute::Shape" @@ -1099,7 +1101,7 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape( sfn_vec_size = operation.ScaleFactorNVecSize sfk_vec_size = operation.ScaleFactorKVecSize blockwise_prepare_code = f""" -using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>; +using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{"90" if operation.arch == 90 else "1xx"}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>; using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA()); using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB()); """ @@ -1477,6 +1479,7 @@ class EmitGemmConfigurationLibrary: GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance, GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance, GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.BlockScaledSparseUniversal3x: EmitGemmUniversal3xInstance, } self.gemm_kind_wrappers = { @@ -1493,6 +1496,7 @@ class EmitGemmConfigurationLibrary: GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation', GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation', GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation', + GemmKind.BlockScaledSparseUniversal3x: 'BlockScaledSparseGemmUniversal3xOperation', } self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 063e8fb1..34fc7336 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -6729,6 +6729,8 @@ try: generate_f8f6f4_math_instructions_sm100, generate_mxf8f6f4_math_instructions_sm100, generate_mxf4nvf4_math_instructions_sm100, + generate_sparse_mxf4nvf4_math_instructions_sm100, + generate_sparse_mxf8f6f4_math_instructions_sm100, generate_fp8_math_instructions_sm100, generate_cluster_shapes_sm100, get_pruning_level_from_global_level @@ -6741,6 +6743,8 @@ except ImportError: generate_f8f6f4_math_instructions_sm100, generate_mxf8f6f4_math_instructions_sm100, generate_mxf4nvf4_math_instructions_sm100, + generate_sparse_mxf4nvf4_math_instructions_sm100, + generate_sparse_mxf8f6f4_math_instructions_sm100, generate_fp8_math_instructions_sm100, generate_cluster_shapes_sm100, get_pruning_level_from_global_level @@ -6804,7 +6808,8 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) math_instructions_1sm, math_instructions_2sm = generate_tf32_math_instructions_sm100(instantiation_level) @@ -6879,7 +6884,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK math_instructions_1sm, math_instructions_2sm = generate_16b_math_instructions_sm100(instantiation_level) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) + grouped = is_grouped(gemm_kind) cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) @@ -7057,7 +7064,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) epi_type = DataType.f32 grouped = is_grouped(gemm_kind) @@ -7505,7 +7513,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_ki thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) epi_type = DataType.f32 @@ -7687,7 +7696,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) epi_type = DataType.f32 @@ -7918,7 +7928,8 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) epi_type = DataType.f32 @@ -8046,8 +8057,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) - nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] - fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule]] + if (data_type["sfd_type"]["type"] == DataType.void): + nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule]) + fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule]) + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) @@ -8170,14 +8185,407 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) - nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] - fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule]] + if (data_type["sfd_type"]["type"] == DataType.void): + nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule]) + fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule]) + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) +def GenerateSM100_SparseTensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledSparseUniversal3x): + # SM100 MMA with F4 inputs + block scale + sparse + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + if grouped: + return # not support for grouped sparse kernels + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 64], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 64], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], + + ] + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = 100 + max_cc = max(max_cc, thor_sm) + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + math_instructions_1sm, math_instructions_2sm = generate_sparse_mxf4nvf4_math_instructions_sm100(instantiation_level) + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.e4m3, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + # E2M1 x E2M1, vector size 16, UE4M3 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + nvfp4_epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1SmNvf4 + fp4_epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1SmMxf4 + nvfp4_kernel_schedule = KernelScheduleType.SparseNvf4TmaWarpSpecialized1SmSm100 + fp4_kernel_schedule = KernelScheduleType.SparseMxf4TmaWarpSpecialized1SmSm100 + + nvfp4_schedules = [[nvfp4_kernel_schedule, nvfp4_epi_schedule]] + fp4_schedules = [[fp4_kernel_schedule, fp4_epi_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.e4m3, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + # E2M1 x E2M1, vector size 16, UE4M3 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + nvfp4_epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2SmNvf4 + fp4_epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2SmMxf4 + + nvfp4_kernel_schedule = KernelScheduleType.SparseNvf4TmaWarpSpecialized2SmSm100 + fp4_kernel_schedule = KernelScheduleType.SparseMxf4TmaWarpSpecialized2SmSm100 + + nvfp4_schedules = [[nvfp4_kernel_schedule, nvfp4_epi_schedule]] + fp4_schedules = [[fp4_kernel_schedule, fp4_epi_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + + +def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledSparseUniversal3x): + # SM100 MMA with mixed F4/F6/F8 inputs + block scale + sparse + instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + if grouped: + return # not support for grouped sparse kernels + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], + ] + math_instructions_1sm, math_instructions_2sm = generate_sparse_mxf8f6f4_math_instructions_sm100(instantiation_level) + + acc_types = [ DataType.f32 ] + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = 100 + max_cc = max(max_cc, thor_sm) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 8 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 64, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 64, "layout" : LayoutType.RowMajor} + }, + ] + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4 + kernel_schedule = KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized1SmSm100 + + schedules = [[kernel_schedule, epi_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 8 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 64, "layout" : LayoutType.RowMajor} + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 64, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 64, "layout" : LayoutType.RowMajor} + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 64, "layout" : LayoutType.RowMajor} + }, + ] + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4 + kernel_schedule = KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized2SmSm100 + + schedules = [[kernel_schedule, epi_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): # SM100 MMA with F4 + block scale if not CudaToolkitVersionSatisfies(cuda_version, 13, 0): @@ -8581,8 +8989,8 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm - + max_cc = 100 + max_cc = max(max_cc, thor_sm) epi_type = DataType.f32 math_instructions_1sm = [ @@ -8798,7 +9206,8 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) tile_schedulers = [ TileSchedulerType.Default, TileSchedulerType.StreamK @@ -8926,7 +9335,8 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) tile_schedulers = [ TileSchedulerType.Default, TileSchedulerType.StreamK @@ -9054,7 +9464,8 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) tile_schedulers = [ TileSchedulerType.Default, TileSchedulerType.StreamK @@ -9181,7 +9592,8 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) tile_schedulers = [ TileSchedulerType.Default, TileSchedulerType.StreamK @@ -9322,7 +9734,8 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): thor_sm = ThorSMRenumbering(cuda_version) min_cc = 100 - max_cc = thor_sm + max_cc = 100 + max_cc = max(max_cc, thor_sm) tile_schedulers = [ TileSchedulerType.Default, TileSchedulerType.StreamK @@ -9536,8 +9949,9 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, thor_sm = ThorSMRenumbering(cuda_version) - minimum_compute_capability = 100 - maximum_compute_capability = thor_sm + min_cc = 100 + max_cc = 100 + max_cc = max(max_cc, thor_sm) spatial_dims = [2, 3] @@ -9584,7 +9998,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, warp_count = [4, 1, 1] tile_description = TileDescription( tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, + min_cc, max_cc, cluster_shape) tile_descriptions.append(tile_description) @@ -9648,7 +10062,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, warp_count = [4, 1, 1] tile_description = TileDescription( tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, + min_cc, max_cc, cluster_shape) tile_descriptions.append(tile_description) @@ -9691,8 +10105,10 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, thor_sm = ThorSMRenumbering(cuda_version) - minimum_compute_capability = 100 - maximum_compute_capability = thor_sm + min_cc = 100 + max_cc = 100 + max_cc = max(max_cc, thor_sm) + spatial_dims = [2, 3] stages = 0 # zero means "deduce the number of stages automatically" @@ -9732,7 +10148,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, warp_count = [4, 1, 1] tile_description = TileDescription( tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, + min_cc, max_cc, cluster_shape) tile_descriptions.append(tile_description) @@ -9797,7 +10213,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, warp_count = [4, 1, 1] tile_description = TileDescription( tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, + min_cc, max_cc, cluster_shape) tile_descriptions.append(tile_description) @@ -10138,6 +10554,252 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio gemm_kind = GemmKind.BlockScaledUniversal3x ) +def GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM120 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] + ] + + instruction_sizes = [ + [16, 8, 64] + ] + + tile_sizes = [ + [128, 64, 256], + [128, 128, 256] + ] + + cluster_shape = [1,1,1] + + ab_types = [ + # DataType.e2m1, + # DataType.e2m3, + # DataType.e3m2, + # DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + + def tile_schedulers(sfdtype, kernel_schedule): + # Pingpong kernel schedule doesn't support stream-K. + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 120 + max_cc = 121 + + epi_type = DataType.f32 + + math_instructions = [] + + kernel_schedules = [ + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120 + ] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types): + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + for math_inst in math_instructions: + tile_descriptions = [] + for tile_size in tile_sizes: + tile_descriptions.append( + TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 64, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type, kernel_schedule in product(data_types, kernel_schedules): + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120]], + tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), + gemm_kind = GemmKind.BlockScaledSparseUniversal3x + ) + +def GenerateSM120_Sparse_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM120 MMA with with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 64], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]] + ] + + instruction_sizes = [ + [16, 8, 64] + ] + + + tile_sizes = [ + [128, 64, 256], + [128, 128, 256] + ] + + cluster_shape = [1,1,1] + + ab_types = [ + DataType.e2m1 + ] + + sf_types = [ + DataType.ue4m3, + # DataType.ue8m0 + ] + + acc_types = [ DataType.f32 ] + + + def is_nvf4(kernel_schedule): + if kernel_schedule == KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: + return True + else: + return False + + def tile_schedulers(sfdtype, kernel_schedule): + # When SFD is void, the epilogue is the traditional linear combination, for which we already have tests with stream-K + + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 120 + max_cc = 121 + + epi_type = DataType.f32 + + math_instructions = [] + + kernel_schedules = [ + KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120, + KernelScheduleType.SparseMxf4TmaWarpSpecializedSm120, + ] + + for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types): + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + for math_inst in math_instructions: + for kernel_schedule in kernel_schedules: + tile_descriptions = [] + for tile_size in tile_sizes: + # nvf4 kernel only supports ue4m3 SF + # mxf4 kernel only supports ue8m0 SF + if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \ + (math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)): + tile_descriptions.append( + TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120]], + tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), + gemm_kind = GemmKind.BlockScaledSparseUniversal3x + ) + def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return @@ -10423,6 +11085,12 @@ def GenerateSM100(manifest, cuda_version): GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version) GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + # + # Block Scaled Sparse Gemm + # + GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_SparseTensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) + # # Conv # @@ -10437,7 +11105,11 @@ def GenerateSM120(manifest, cuda_version): # GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) - + # + # Sparse Block Scaled Gemm + # + GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM120_Sparse_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) # # Sparse Gemm # @@ -10460,8 +11132,8 @@ def GenerateSM90_Conv3x(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - minimum_compute_capability = 90 - maximum_compute_capability = 90 + min_cc = 90 + max_cc = 90 spatial_dims = (2, 3) @@ -10800,7 +11472,7 @@ def GenerateSM90_Conv3x(manifest, cuda_version, math_inst = make_math_instruction(data_types, mma_shape) tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2]) tile_description = TileDescription(tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, cluster_shape) + min_cc, max_cc, cluster_shape) assert(isinstance(spatial_dim, int)) dims_and_alignments = ( ( diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 56d22dc4..3440360d 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -322,7 +322,7 @@ def is_complex(data_type): return False def is_block_scaled(gemm_kind): - return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x) + return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x, GemmKind.BlockScaledSparseUniversal3x) def is_blockwise(gemm_kind): return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) @@ -548,6 +548,12 @@ class KernelScheduleType(enum.Enum): Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() + SparseMxf4TmaWarpSpecialized1SmSm100 = enum_auto() + SparseMxf4TmaWarpSpecialized2SmSm100 = enum_auto() + SparseNvf4TmaWarpSpecialized1SmSm100 = enum_auto() + SparseNvf4TmaWarpSpecialized2SmSm100 = enum_auto() + SparseMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + SparseMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() # FP4 Ultra MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() @@ -586,6 +592,10 @@ class KernelScheduleType(enum.Enum): Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto() + SparseMxf8f6f4TmaWarpSpecializedSm120 = enum_auto() + SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120 = enum_auto() + SparseNvf4TmaWarpSpecializedSm120 = enum_auto() + SparseMxf4TmaWarpSpecializedSm120 = enum_auto() F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() @@ -637,6 +647,13 @@ KernelScheduleTag = { KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', + KernelScheduleType.SparseMxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf4Sm100', + KernelScheduleType.SparseMxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf4Sm100', + KernelScheduleType.SparseNvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmNvf4Sm100', + KernelScheduleType.SparseNvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmNvf4Sm100', + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100', + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100', + # FP4 Ultra KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', @@ -694,6 +711,10 @@ KernelScheduleTag = { KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120', KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120', + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Sm120', + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120', + KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120', + KernelScheduleType.SparseMxf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf4Sm120', } # @@ -742,6 +763,14 @@ KernelScheduleSuffixes = { KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + KernelScheduleType.SparseMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.SparseMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.SparseNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.SparseNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', @@ -796,6 +825,11 @@ KernelScheduleSuffixes = { KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32', KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32', + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: '_q', + KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: '_acc2x4_q', + KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: '_o_vs16', + KernelScheduleType.SparseMxf4TmaWarpSpecializedSm120: '_o_vs32', + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q', KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q', @@ -827,6 +861,13 @@ class EpilogueScheduleType(enum.Enum): PtrArrayTmaWarpSpecialized2Sm = enum_auto() PtrArrayTmaWarpSpecializedPingpong = enum_auto() PtrArrayTmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecialized1SmNvf4 = enum_auto() + TmaWarpSpecialized2SmNvf4 = enum_auto() + TmaWarpSpecialized1SmMxf4 = enum_auto() + TmaWarpSpecialized2SmMxf4 = enum_auto() + TmaWarpSpecialized1SmMxf8f6f4 = enum_auto() + TmaWarpSpecialized2SmMxf8f6f4 = enum_auto() + SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() # EpilogueScheduleTag = { @@ -854,6 +895,13 @@ EpilogueScheduleTag = { EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm', EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative', EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong', + EpilogueScheduleType.TmaWarpSpecialized1SmNvf4: 'cutlass::epilogue::TmaWarpSpecialized1SmNvf4', + EpilogueScheduleType.TmaWarpSpecialized2SmNvf4: 'cutlass::epilogue::TmaWarpSpecialized2SmNvf4', + EpilogueScheduleType.TmaWarpSpecialized1SmMxf4: 'cutlass::epilogue::TmaWarpSpecialized1SmMxf4', + EpilogueScheduleType.TmaWarpSpecialized2SmMxf4: 'cutlass::epilogue::TmaWarpSpecialized2SmMxf4', + EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4: 'cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4', + EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4: 'cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4', + EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120', } # @@ -882,6 +930,13 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma', EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized1SmNvf4: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized2SmNvf4: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized1SmMxf4: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized2SmMxf4: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4: '_epi_tma', + EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120: '_epi_tma', } class EpilogueFunctor3x(enum.Enum): @@ -906,6 +961,12 @@ def is_tma_epilogue(epilogue_schedule_type): EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, + EpilogueScheduleType.TmaWarpSpecialized1SmNvf4, + EpilogueScheduleType.TmaWarpSpecialized2SmNvf4, + EpilogueScheduleType.TmaWarpSpecialized1SmMxf4, + EpilogueScheduleType.TmaWarpSpecialized2SmMxf4, + EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4, + EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4, ] def to_grouped_schedule(schedule, grouped): @@ -1040,7 +1101,8 @@ class OpcodeClass(enum.Enum): TensorOp = enum_auto() WmmaTensorOp = enum_auto() SparseTensorOp = enum_auto() - BlockScaledTensorOp = enum_auto() + BlockScaledTensorOp = enum_auto() + BlockScaledSparseTensorOp = enum_auto() OpcodeClassNames = { @@ -1048,7 +1110,8 @@ OpcodeClassNames = { OpcodeClass.TensorOp: 'tensorop', OpcodeClass.WmmaTensorOp: 'wmma_tensorop', OpcodeClass.SparseTensorOp: 'sptensorop', - OpcodeClass.BlockScaledTensorOp: 'bstensorop' + OpcodeClass.BlockScaledTensorOp: 'bstensorop', + OpcodeClass.BlockScaledSparseTensorOp: 'bssptensorop' } OpcodeClassTag = { @@ -1056,7 +1119,8 @@ OpcodeClassTag = { OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp', - OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp' + OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp', + OpcodeClass.BlockScaledSparseTensorOp: 'cutlass::arch::OpClassBlockScaledSparseTensorOp' } ################################################################################################### @@ -1143,6 +1207,7 @@ class GemmKind(enum.Enum): GroupedBlockScaledUniversal3x = enum_auto() BlockwiseUniversal3x = enum_auto() GroupedBlockwiseUniversal3x = enum_auto() + BlockScaledSparseUniversal3x = enum_auto() # GemmKindNames = { @@ -1158,7 +1223,8 @@ GemmKindNames = { GemmKind.GroupedUniversal3x: "gemm_grouped", GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped", GemmKind.BlockwiseUniversal3x: "gemm", - GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped" + GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped", + GemmKind.BlockScaledSparseUniversal3x: "spgemm" } # diff --git a/python/cutlass_library/sm100_shapes.py b/python/cutlass_library/sm100_shapes.py index 32e43765..51f405e8 100644 --- a/python/cutlass_library/sm100_shapes.py +++ b/python/cutlass_library/sm100_shapes.py @@ -315,6 +315,12 @@ SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = { } +SM100_MMA_SHAPES_MXF8F6F4_SPARSE_1SM = { + (128, 128, 32): 0, + (128, 192, 32): 1, + (128, 256, 32): 0, +} + SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { (256, 64, 32): 1, (256, 128, 32): 0, @@ -324,6 +330,15 @@ SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { } + +SM100_MMA_SHAPES_MXF8F6F4_SPARSE_2SM = { + (256, 128, 32): 0, + (256, 192, 32): 1, + (256, 256, 32): 0, + +} + + # MXF4NVF4 SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { (128, 64, 64): 1, @@ -332,6 +347,13 @@ SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { (128, 256, 64): 0, } + +SM100_MMA_SHAPES_MXF4NVF4_SPARSE_1SM = { + (128, 128, 64): 0, + (128, 256, 64): 0, +} + + SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { # Multiples of 16 for N (256, 64, 64): 1, @@ -340,3 +362,11 @@ SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { (256, 256, 64): 0, } + + +SM100_MMA_SHAPES_MXF4NVF4_SPARSE_2SM = { + # Multiples of 16 for N + (256, 128, 64): 0, + (256, 256, 64): 0, + +} diff --git a/python/cutlass_library/sm100_utils.py b/python/cutlass_library/sm100_utils.py index 9bf24fe7..a74b8154 100644 --- a/python/cutlass_library/sm100_utils.py +++ b/python/cutlass_library/sm100_utils.py @@ -659,3 +659,228 @@ def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Calla ] return shapes_1sm, shapes_2sm + +def generate_sparse_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = False, enable_compile_time_dtype = True): + """ + Generate all BlockScaledSparseTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_SPARSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_SPARSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + # math_instructions_1sm.append( + # MathInstruction( + # shape, + # a_type, b_type, DataType.f32, + # OpcodeClass.BlockScaledSparseTensorOp, + # MathOperation.multiply_add, + # DataType.ue8m0) + # ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + # math_instructions_1sm.append( + # MathInstruction( + # shape, + # a_type, b_type, DataType.f32, + # OpcodeClass.BlockScaledSparseTensorOp, + # MathOperation.multiply_add, + # DataType.ue8m0) + # ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + # math_instructions_2sm.append( + # MathInstruction( + # shape, + # a_type, b_type, DataType.f32, + # OpcodeClass.BlockScaledSparseTensorOp, + # MathOperation.multiply_add, + # DataType.ue8m0) + # ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + # math_instructions_2sm.append( + # MathInstruction( + # shape, + # a_type, b_type, DataType.f32, + # OpcodeClass.BlockScaledSparseTensorOp, + # MathOperation.multiply_add, + # DataType.ue8m0) + # ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + return math_instructions_1sm, math_instructions_2sm + + +def generate_sparse_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = False, enable_compile_time_dtype = True): + """ + Generate all BlockScaledSparseTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_SPARSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_SPARSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + # DataType.e5m2, + # DataType.e3m2, + # DataType.e2m3, + # DataType.e2m1 + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + # DataType.e5m2, + # DataType.e3m2, + # DataType.e2m3, + # DataType.e2m1 + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledSparseTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + return math_instructions_1sm, math_instructions_2sm diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index acc0c46e..8b53d8f4 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup_pycute.perform_setup() setup( name='cutlass_cppgen', - version='4.2.0', + version='4.3.0', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_library.py b/python/setup_library.py index c56d6b55..4ee11af9 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='cutlass_library', - version='4.2.1', + version='4.3.0', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 0bad050f..37ba01cb 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='4.2.1', + version='4.3.0', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/test/unit/conv/device_3x/dgrad/CMakeLists.txt b/test/unit/conv/device_3x/dgrad/CMakeLists.txt index cb7abd30..826303a9 100644 --- a/test/unit/conv/device_3x/dgrad/CMakeLists.txt +++ b/test/unit/conv/device_3x/dgrad/CMakeLists.txt @@ -32,6 +32,7 @@ add_custom_target( cutlass_test_unit_conv_dgrad_device_tensorop_sm90 cutlass_test_unit_conv_dgrad_device_tensorop_sm100 cutlass_test_unit_conv_dgrad_device_tensorop_sm100_fusion + cutlass_test_unit_conv_dgrad_device_tensorop_sm100_streamk ) cutlass_test_unit_add_executable( @@ -88,4 +89,10 @@ cutlass_test_unit_add_executable( sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu ) +cutlass_test_unit_add_executable_split_file( + cutlass_test_unit_conv_dgrad_device_tensorop_sm100_streamk + + sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +) + endif() diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu new file mode 100644 index 00000000..10c9f2e4 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + /*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/CMakeLists.txt b/test/unit/conv/device_3x/fprop/CMakeLists.txt index 4091595c..b56a3eea 100644 --- a/test/unit/conv/device_3x/fprop/CMakeLists.txt +++ b/test/unit/conv/device_3x/fprop/CMakeLists.txt @@ -34,6 +34,7 @@ add_custom_target( cutlass_test_unit_conv3d_fprop_device_tensorop_sm90 cutlass_test_unit_conv_fprop_device_tensorop_sm100 cutlass_test_unit_conv_fprop_device_tensorop_sm100_fusion + cutlass_test_unit_conv_fprop_device_tensorop_sm100_streamk ) cutlass_test_unit_add_executable( @@ -121,4 +122,14 @@ cutlass_test_unit_add_executable( sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu ) +cutlass_test_unit_add_executable_split_file( + cutlass_test_unit_conv_fprop_device_tensorop_sm100_streamk + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +) + endif() diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu new file mode 100644 index 00000000..79988ed4 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + /*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape = cutlass::conv::ConvProblemShape; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/CMakeLists.txt b/test/unit/conv/device_3x/wgrad/CMakeLists.txt index f1c4041b..a23fa009 100644 --- a/test/unit/conv/device_3x/wgrad/CMakeLists.txt +++ b/test/unit/conv/device_3x/wgrad/CMakeLists.txt @@ -32,6 +32,9 @@ add_custom_target( cutlass_test_unit_conv_wgrad_device_tensorop_sm90 cutlass_test_unit_conv_wgrad_device_tensorop_sm100 cutlass_test_unit_conv_wgrad_device_tensorop_sm100_fusion + cutlass_test_unit_conv1d_wgrad_device_tensorop_sm100_streamk + cutlass_test_unit_conv2d_wgrad_device_tensorop_sm100_streamk + cutlass_test_unit_conv3d_wgrad_device_tensorop_sm100_streamk ) cutlass_test_unit_add_executable( @@ -68,4 +71,22 @@ cutlass_test_unit_add_executable_split_file( sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu ) +cutlass_test_unit_add_executable_split_file( + cutlass_test_unit_conv1d_wgrad_device_tensorop_sm100_streamk + + sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +) + +cutlass_test_unit_add_executable_split_file( + cutlass_test_unit_conv2d_wgrad_device_tensorop_sm100_streamk + + sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +) + +cutlass_test_unit_add_executable_split_file( + cutlass_test_unit_conv3d_wgrad_device_tensorop_sm100_streamk + + sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +) + endif() diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu new file mode 100644 index 00000000..465af928 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + /*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu new file mode 100644 index 00000000..d2196498 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu new file mode 100644 index 00000000..58fc242b --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/core/uint128.cu b/test/unit/core/uint128.cu index a6872328..75dc21d1 100644 --- a/test/unit/core/uint128.cu +++ b/test/unit/core/uint128.cu @@ -58,6 +58,11 @@ TEST(uint128_t, host_arithmetic) { T y = j; EXPECT_TRUE(static_cast(x + y) == (i + j)); + EXPECT_TRUE(static_cast(x * static_cast(y)) == (i * j)); + + if (j != 0) { + EXPECT_TRUE(static_cast(x / static_cast(y)) == (i / j)); + } } } diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu index af866bc5..caad03f0 100644 --- a/test/unit/cute/ampere/cooperative_gemm.cu +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -537,6 +537,8 @@ TEST(SM80_CuTe_Ampere, CooperativeGemmLDSMx2) { SM75_U32x2_LDSM_N{}); } +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) + TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e4m3f32_MMA) { using TA = cutlass::float_e4m3_t; using TB = cutlass::float_e4m3_t; @@ -609,8 +611,6 @@ TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f32_MMA) { test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } -#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) - TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e4m3f16_MMA) { using TA = cutlass::float_e4m3_t; using TB = cutlass::float_e4m3_t; diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 3b474b0a..0b19413a 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -645,6 +645,7 @@ endif() if (CUTLASS_NVCC_DEVICE_COMPILE) +if (NOT CUTLASS_NVCC_ARCHS MATCHES 101|101a|101f|110|110a|110f) cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_blas3 @@ -808,7 +809,7 @@ cutlass_test_unit_gemm_device_add_executable( hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu ) -if (NOT CUTLASS_NVCC_ARCHS MATCHES 100f|101|101a|101f|103|103a|103f) +if (NOT CUTLASS_NVCC_ARCHS MATCHES 100f|101|101a|101f|110|110a|110f|103|103a|103f) cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_blas3_gaussian @@ -844,6 +845,7 @@ cutlass_test_unit_gemm_device_add_executable( her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu ) +endif() endif() diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 90dfffa4..d7dbde84 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -35,6 +35,9 @@ include(GNUInstallDirs) set(CUTLASS_BUILD_MONO_LIBRARY OFF CACHE BOOL "Determines whether the cutlass library is generated as a single file or multiple files.") +option(CUTLASS_BUILD_SHARED_LIBS "Build shared libraries" ON) +option(CUTLASS_BUILD_STATIC_LIBS "Build static libraries" ON) + ################################################################################ add_library(cutlass_library_includes INTERFACE) @@ -62,7 +65,7 @@ install( install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) add_library(cutlass_library_internal_interface INTERFACE) @@ -123,88 +126,98 @@ function(cutlass_add_cutlass_library) if (CUTLASS_BUILD_MONO_LIBRARY AND __SUFFIX) # If we're only building a single monolithic library then we - # simply link the generated object files to the default library. + # simply link the generated object files to the default library. + if(CUTLASS_BUILD_SHARED_LIBS) + target_link_libraries(${DEFAULT_NAME} PRIVATE $) + endif() - target_link_libraries(${DEFAULT_NAME} PRIVATE $) - target_link_libraries(${DEFAULT_NAME}_static PRIVATE $) + if(CUTLASS_BUILD_STATIC_LIBS) + target_link_libraries(${DEFAULT_NAME}_static PRIVATE $) + endif() else() - cutlass_add_library( - ${__NAME} - SHARED - EXPORT_NAME ${__EXPORT_NAME} - "" + # Shared library (honors CMake's standard CUTLASS_BUILD_SHARED_LIBS) + if(CUTLASS_BUILD_SHARED_LIBS) + cutlass_add_library( + ${__NAME} + SHARED + EXPORT_NAME ${__EXPORT_NAME} + "" ) - target_compile_features(${__NAME} INTERFACE cxx_std_17) - - set_target_properties( - ${__NAME} - PROPERTIES - OUTPUT_NAME ${__OUTPUT_NAME} - WINDOWS_EXPORT_ALL_SYMBOLS 1 - ) - - target_link_libraries( - ${__NAME} - PUBLIC cutlass_library_includes - PRIVATE $ - cuda_driver - ) - - set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}") - - cutlass_add_library( - ${__NAME}_static - STATIC - EXPORT_NAME ${__EXPORT_NAME}_static - "" + target_compile_features(${__NAME} INTERFACE cxx_std_17) + + set_target_properties(${__NAME} + PROPERTIES + OUTPUT_NAME ${__OUTPUT_NAME} + WINDOWS_EXPORT_ALL_SYMBOLS 1 + DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}" ) - target_compile_features(${__NAME}_static INTERFACE cxx_std_17) - - if (WIN32) - set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}.static) - else() - set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}) - endif() - - set_target_properties( - ${__NAME}_static - PROPERTIES - OUTPUT_NAME ${STATIC_OUTPUT_NAME} - WINDOWS_EXPORT_ALL_SYMBOLS 1 + target_link_libraries(${__NAME} + PUBLIC cutlass_library_includes + PRIVATE $ + cuda_driver ) - - target_link_libraries( - ${__NAME}_static - PUBLIC cutlass_library_includes - PRIVATE $ - cuda_driver + + install( + TARGETS ${__NAME} + EXPORT NvidiaCutlass + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ) - - set_target_properties(${__NAME}_static PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}") - - install( - TARGETS ${__NAME} ${__NAME}_static - EXPORT NvidiaCutlass - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - ) - - if (__SUFFIX) - - # The partial libraries generated will be registered as linked libraries - # to the main cutlass library so users automatically get the necessary link - # commands to pull in all kernels by default. - - target_link_libraries(${DEFAULT_NAME} PUBLIC ${__NAME}) - target_link_libraries(${DEFAULT_NAME}_static PUBLIC ${__NAME}_static) - + + if (__SUFFIX) + target_link_libraries(${DEFAULT_NAME} PUBLIC ${__NAME}) + endif() endif() + # Static library + if(CUTLASS_BUILD_STATIC_LIBS) + cutlass_add_library( + ${__NAME}_static + STATIC + EXPORT_NAME ${__EXPORT_NAME}_static + "" + ) + + target_compile_features(${__NAME}_static INTERFACE cxx_std_17) + + if (WIN32) + set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}.static) + else() + set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}) + endif() + + set_target_properties( + ${__NAME}_static + PROPERTIES + OUTPUT_NAME ${STATIC_OUTPUT_NAME} + WINDOWS_EXPORT_ALL_SYMBOLS 1 + DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}" + ) + + target_link_libraries( + ${__NAME}_static + PUBLIC cutlass_library_includes + PRIVATE $ + cuda_driver + ) + + install( + TARGETS ${__NAME}_static + EXPORT NvidiaCutlass + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + ) + + if (__SUFFIX) + target_link_libraries(${DEFAULT_NAME}_static PUBLIC ${__NAME}_static) + endif() + endif() endif() endfunction() @@ -268,8 +281,13 @@ cutlass_add_cutlass_library( ) # For backward compatibility with the old name -add_library(cutlass_lib ALIAS cutlass_library) -add_library(cutlass_lib_static ALIAS cutlass_library_static) +if(CUTLASS_BUILD_SHARED_LIBS) + add_library(cutlass_lib ALIAS cutlass_library) +endif() + +if(CUTLASS_BUILD_STATIC_LIBS) + add_library(cutlass_lib_static ALIAS cutlass_library_static) +endif() ################################################################################ diff --git a/tools/library/include/cutlass/library/descriptions.h b/tools/library/include/cutlass/library/descriptions.h index 5e80c124..6f1dc5ff 100644 --- a/tools/library/include/cutlass/library/descriptions.h +++ b/tools/library/include/cutlass/library/descriptions.h @@ -348,6 +348,9 @@ struct BlockScaledGemmDescription : public OperationDescription { /// Describes the destination matrix TensorDescription D; + /// Describes the sparse meta matrices + TensorDescription E; + /// Describes the SFA operand TensorDescription SFA; diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 6764d9a6..ca843ce1 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -392,10 +392,11 @@ struct BlockScaledGemmArguments { library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; - + int device_index{0}; bool use_pdl{false}; }; + /// Blockwise GEMM // // OperationKind: kBlockwiseGemm diff --git a/tools/library/include/cutlass/library/types.h b/tools/library/include/cutlass/library/types.h index 9f8c4ff1..2ec95c7a 100644 --- a/tools/library/include/cutlass/library/types.h +++ b/tools/library/include/cutlass/library/types.h @@ -209,6 +209,7 @@ enum class GemmKind { kPlanarComplex, kPlanarComplexArray, kGrouped, + kBlockScaledSparseGemm, kInvalid }; diff --git a/tools/library/src/block_scaled_gemm_operation_3x.hpp b/tools/library/src/block_scaled_gemm_operation_3x.hpp index c96b9a22..ff75a8a6 100644 --- a/tools/library/src/block_scaled_gemm_operation_3x.hpp +++ b/tools/library/src/block_scaled_gemm_operation_3x.hpp @@ -41,6 +41,8 @@ #include "cutlass/library/library.h" #include "library_internal.h" #include "gemm_operation_3x.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" /////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::library { @@ -48,7 +50,7 @@ namespace cutlass::library { /////////////////////////////////////////////////////////////////////////////////////////////////// template -class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase { +class BlockScaledGemmUniversal3xOperationBase : public GemmOperation3xBase { public: using Operator = Operator_; using OperatorArguments = typename Operator::Arguments; @@ -92,15 +94,9 @@ public: static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA; using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB; - - -private: - BlockScaledGemmDescription description_; public: - - /// Constructor - BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): + BlockScaledGemmUniversal3xOperationBase(char const *name = "unknown_gemm"): GemmOperation3xBase(name, GemmKind::kUniversal) { description_.kind = OperationKind::kBlockScaledGemm; description_.SFA.element = NumericTypeMap::kId; @@ -182,38 +178,14 @@ public: BlockScaledGemmDescription const& get_gemm_description() const { return description_; } - protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { - // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides - // Do nothing here and construct kernel arguments in update_arguments_ instead - // We also cannot construct TMA descriptors without all the arguments available - - operator_args.mode = configuration->mode; - return Status::kSuccess; - } - - template - struct UpdateFusionArgs { - static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { - // If a custom EVT is instantiated then it is the users's responsibility - // to ensure alpha and beta are updated appropriately - return Status::kSuccess; - } - }; - - template - struct UpdateFusionArgs> { - static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { - + BlockScaledGemmDescription description_; +template + static Status update_fusion_args(FusionArgs& fusion_args, BlockScaledGemmArguments const& arguments) { if constexpr (epilogue_scalefactor_generation) { fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); } - if (arguments.pointer_mode == ScalarPointerMode::kHost) { fusion_args.alpha = *static_cast(arguments.alpha); @@ -234,21 +206,12 @@ protected: else { return Status::kErrorInvalidProblem; } - } - }; + } /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - BlockScaledGemmArguments const *arguments) { - Status status = Status::kSuccess; - - status = UpdateFusionArgs::update_( - operator_args.epilogue.thread, *arguments); - if (status != Status::kSuccess) { - return status; - } - + static Status update_arguments_base( + OperatorArguments& operator_args, + BlockScaledGemmArguments const* arguments) { operator_args.problem_shape = cute::make_shape( arguments->problem_size.m(), arguments->problem_size.n(), @@ -256,11 +219,10 @@ protected: arguments->batch_count); // update arguments - + if constexpr (IsRuntimeDataType) { using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; - operator_args.mainloop.ptr_A = static_cast(arguments->A); operator_args.mainloop.ptr_B = static_cast(arguments->B); using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; @@ -298,17 +260,12 @@ protected: } else { - - operator_args.mainloop.ptr_A = static_cast(arguments->A); operator_args.mainloop.ptr_B = static_cast(arguments->B); } operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); operator_args.epilogue.ptr_C = static_cast(arguments->C); operator_args.epilogue.ptr_D = static_cast(arguments->D); - - operator_args.mainloop.dA = cute::make_int_tuple_from( - arguments->lda, arguments->batch_stride_A); operator_args.mainloop.dB = cute::make_int_tuple_from( arguments->ldb, arguments->batch_stride_B); operator_args.epilogue.dC = cute::make_int_tuple_from( @@ -353,7 +310,74 @@ protected: arguments->cluster_shape_fallback.n(), arguments->cluster_shape_fallback.k()); } - + return Status::kSuccess; + } +}; + +template +class BlockScaledGemmUniversal3xOperation : public BlockScaledGemmUniversal3xOperationBase { +public: + using Base = BlockScaledGemmUniversal3xOperationBase; + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): + BlockScaledGemmUniversal3xOperationBase(name) {} + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { + + return Base::update_fusion_args(fusion_args, arguments); + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockScaledGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + if constexpr (Base::IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + + } else { + operator_args.mainloop.ptr_A = static_cast(arguments->A); + } + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + status = Base::update_arguments_base(operator_args, arguments); return status; } @@ -443,6 +467,306 @@ public: return status; } }; + +template +class BlockScaledSparseGemmUniversal3xOperation : public BlockScaledGemmUniversal3xOperationBase { +public: + using Base = BlockScaledGemmUniversal3xOperationBase; + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ArchTag = typename Operator::ArchTag; + using StrideA = cutlass::gemm::TagToStrideA_t; + using ElementE = typename Operator::CollectiveMainloop::ElementE; + using LayoutE = typename Operator::CollectiveMainloop::LayoutE; + using SparseConfig = typename Operator::CollectiveMainloop::SparseConfig; + using ProblemShape = typename Operator::GemmKernel::ProblemShape; + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + typename Base::ElementA, + typename Base::LayoutA, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + typename Base::ElementA, + typename Base::LayoutA, + SparseConfig, + ArchTag>; + using Compressor = cutlass::transform::device::TransformUniversalAdapter; +private: + // Variables that must change in the const functions. + mutable CompressorUtility compressor_utility; + mutable int problem_count = 1; + mutable std::vector iter_idx; + + mutable uint64_t tensor_ac_size = 0; + mutable uint64_t tensor_e_size = 0; + mutable uint64_t tensor_a_size = 0; + mutable uint64_t host_op_workspace_size = 0; + mutable uint64_t device_compress_workspace_size = 0; + mutable uint64_t device_op_workspace_size = 0; + mutable uint64_t device_per_iter_workspace_size = 0; + +public: + BlockScaledSparseGemmUniversal3xOperation(char const *name = "unknown_gemm"):Base(name) { + this->description_.E = make_TensorDescription(typename SparseConfig::TensorEAlignmentK{}); + } +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { + + return Base::update_fusion_args(fusion_args, arguments); + } + }; + + static Status update_arguments_( + OperatorArguments &operator_args, + BlockScaledGemmArguments const *arguments, + CompressorUtility const& compressor_utility, + void* device_a_compressed_ptr = nullptr, + void* device_e_ptr = nullptr) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + // update arguments + if constexpr (Base::IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + } else { + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + } + operator_args.mainloop.ptr_E = static_cast(device_e_ptr); + + operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor(); + operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor(); + + status = Base::update_arguments_base(operator_args, arguments); + return status; +} +public: + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr), compressor_utility); + if (status != Status::kSuccess) { + return 0; + } + typename Compressor::Arguments compress_arguments { + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {/*Empty Not Use*/}, + {/*Empty Not Use*/} }; + + // Size for one iteration + // For multi-iteration, will need to multiply result of this function w/ actual problem_count + tensor_ac_size = compressor_utility.get_compressed_tensor_A_bytes(); + tensor_e_size = compressor_utility.get_tensor_E_bytes(); + device_op_workspace_size = Operator::get_workspace_size(args); + device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments); + // NOTE: order here is the order of workspace partition + device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size; + + return device_per_iter_workspace_size; + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + // Memory to hold operator + host_op_workspace_size = sizeof(Operator); + + // Memory to hold result of `.structure_sparse_zero_mask_fill()` + tensor_a_size = compressor_utility.get_raw_tensor_A_bytes(); + + // NOTE: order here is the order of workspace partition + const uint64_t size = host_op_workspace_size + tensor_a_size; + return size; + } + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockScaledGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto problem_shape_MNKL = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + const int M = configuration->problem_size.m(); + const int N = configuration->problem_size.n(); + const int K = configuration->problem_size.k(); + const int L = configuration->batch_count; + using StrideA = typename CompressorUtility::StrideA; + auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + compressor_utility.set_problem_size(problem_shape_MNKL, dA); + + auto status = update_arguments_(args, arguments, compressor_utility); + if (status != Status::kSuccess) { + return status; + } + // can_implement rules may need access to problem shape + args.problem_shape = problem_shape_MNKL; + + return Operator::can_implement(args); + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + iter_idx.resize(static_cast(configuration)->device_count, 0); + // Set problem_count. + problem_count = problem_count_from_profiler; + + // * Host Ptr + auto* host_op_workspace_ptr = reinterpret_cast(host_workspace); + auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size; + + // * Construct Op + Operator *op = new (host_op_workspace_ptr) Operator; + + // * Device Ptr (1st iteration) + // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iter1 = static_cast(device_workspace); + auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; + auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; + auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; + auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size; + // * Device A Raw Ptr + auto* device_a_raw_ptr = profiler_workspaces[0]; + // * Random fill 50% of TensorA w/ zero following the structured sparse requirement + CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); + compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); + CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); + + CUDA_CHECK(cudaGetLastError()); + + // * Compress DTensorA and get DTensorAC & DTensorE + cutlass::KernelHardwareInfo hw_info; + CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {device_a_raw_ptr, + compressor_utility.dA, + device_a_compressed_ptr_iter1, + device_e_ptr_iter1}, + {hw_info} + }; + + cutlass::Status status {cutlass::Status::kSuccess}; + + Compressor compressor_op; + status = compressor_op.can_implement(arguments); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.run(stream); + if (status != Status::kSuccess) { + return status; + } + + // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE + for (int iter_i = 1; iter_i < problem_count; iter_i++) { + // * Device AC E Ptr per iteration + // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + + CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); + } + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + CUDA_CHECK(cudaGetLastError()); + return Status::kSuccess; + } + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + const auto device_index = static_cast(arguments_ptr)->device_index; + + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; + + Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::library diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index 2c1d1794..4b510ffa 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -185,23 +185,11 @@ public: /// Constructor GemmUniversal3xOperation(char const *name = "unknown_gemm"): - GemmOperation3xBase(name, GemmKind::kUniversal) { - if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { - dim3 cluster_dims( - cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); - uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; - void const* kernel_ptr = (void*)(device_kernel); - max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( - cluster_dims, - threads_per_block, - kernel_ptr); - } - } + GemmOperation3xBase(name, GemmKind::kUniversal) {} private: - int max_active_clusters{}; + // mutable because it needs to be set in initialize (see comment in initialize) + mutable int max_active_clusters{}; protected: @@ -683,6 +671,21 @@ public: void *host_workspace, void *device_workspace, cudaStream_t stream = nullptr) const override { + // this would ideally go in the constructor, but + // the constructor is called at profiler startup for EVERY kernel, + // REGARDLESS of whether the kernel is actually supported on the device + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + dim3 cluster_dims( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + } Operator *op = new (host_workspace) Operator; return Status::kSuccess; } diff --git a/tools/library/src/reference/gemm_f4_f8_f32.cu b/tools/library/src/reference/gemm_f4_f8_f32.cu index c4fec090..b57bd1ff 100644 --- a/tools/library/src/reference/gemm_f4_f8_f32.cu +++ b/tools/library/src/reference/gemm_f4_f8_f32.cu @@ -99,6 +99,46 @@ void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest) { float // ElementD >(manifest); + // 1. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index 2e343589..f2b10dac 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -104,11 +104,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 A set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) else() set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) - if (90a IN_LIST CUTLASS_NVCC_ARCHS_ENABLED OR (90 IN_LIST CUTLASS_NVCC_ARCHS_ENABLED)) - set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) - else() - set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --mode=trace --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) - endif() + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) endif() set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true) diff --git a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu index 84faffbe..05907038 100644 --- a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu +++ b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu @@ -355,11 +355,17 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse( int64_t BlockScaledGemmOperationProfiler::GemmProblem::bytes_with_problem_shape( library::BlockScaledGemmDescription const &operation_desc, gemm::GemmCoord const &problem_shape) const { + + int sfa_m = round_up(problem_shape.m(), 128); + int sfb_n = round_up(problem_shape.n(), 128); + int sfa_sfb_k = round_up(ceil_div(problem_shape.k(), operation_desc.SFVecSize), 4); // Input bytes read and Output bytes written for the gemm problem int64_t bytes = int64_t(library::sizeof_bits(operation_desc.A.element) * problem_shape.m() / 8) * problem_shape.k() + int64_t(library::sizeof_bits(operation_desc.B.element) * problem_shape.n() / 8) * problem_shape.k() + - int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n(); + int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n() + + int64_t(library::sizeof_bits(operation_desc.SFA.element) * sfa_m / 8) * sfa_sfb_k + + int64_t(library::sizeof_bits(operation_desc.SFB.element) * sfb_n / 8) * sfa_sfb_k; // Set is_beta_zero true if beta is zero bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); @@ -726,6 +732,8 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace( library::BlockScaledGemmDescription const &operation_desc = static_cast(operation->description()); + bool is_sparse = operation_desc.tile_description.math_instruction.opcode_class == cutlass::library::OpcodeClassID::kSparseTensorOp; + // Compute the number of copies of the problem to avoid L2 camping. if (!options.profiling.workspace_count) { int64_t bytes = problem_.bytes(operation_desc); @@ -917,6 +925,7 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace( /* Query device SM count to pass onto the kernel as an argument, where needed */ gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0); + gemm_workspace_.arguments.device_index = static_cast(0); } // @@ -932,12 +941,34 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace( workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, &gemm_workspace_.arguments); + if (is_sparse) { + // sparse gemm get_device_workspace_size() only return device workspace size per iteration + // Needs to multiply it w/ number of iteration + workspace_size *= gemm_workspace_.problem_count; + } gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + // Convert to structure sparse contents here. + if (is_sparse) { + uint8_t* profiler_workspaces[1]; + profiler_workspaces[0] = reinterpret_cast(gemm_workspace_.A->data()); + // Sparse operations have a different initialize interface. + // initialize_with_profiler_workspace converts mxk tensorA to compressed mxk/sp tensorA and the tensorE + auto modifiable_underlying_op = const_cast(underlying_operation); + status = modifiable_underlying_op->initialize_with_profiler_workspace( + &gemm_workspace_.configuration, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data(), + profiler_workspaces, + gemm_workspace_.problem_count); + } + else { status = underlying_operation->initialize( &gemm_workspace_.configuration, gemm_workspace_.host_workspace.data(), gemm_workspace_.device_workspace.data()); + } + if (status != Status::kSuccess) { return status; } diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index 6e3fb4b8..0abac6d8 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -63,7 +63,7 @@ CutlassProfiler::CutlassProfiler( operation_profilers_.emplace_back(new GemmOperationProfiler(options)); - operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options)); + operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options)); operation_profilers_.emplace_back(new BlockwiseGemmOperationProfiler(options)); diff --git a/tools/util/CMakeLists.txt b/tools/util/CMakeLists.txt index b69ea023..bd9834d7 100644 --- a/tools/util/CMakeLists.txt +++ b/tools/util/CMakeLists.txt @@ -45,8 +45,8 @@ target_link_libraries( install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ - ) + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) install( TARGETS cutlass_tools_util_includes