diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f423fb2..9ca22eaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,14 +17,16 @@ * Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. - - Support for [grouped GEMM with blockwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. -* Added support for enhanced kernel performance search in CUTLASS: + - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture. +* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler: - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance. - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. - - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). +* Support `void` as the D element in sm100 kernel epilogues. ## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25) @@ -40,7 +42,7 @@ - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp). - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp). - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. - - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). + - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. * Full support for Blackwell SM100 kernels in CUTLASS 3.x API: - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that @@ -78,11 +80,11 @@ - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes. - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu). * Documentation updates: - - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). - - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md) - - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. + - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-gemm-kernel). + - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md) + - A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). - - Updates to [profiler documentation](./media/docs/profiler.md) for testing mixed input GEMM kernels on Hopper. + - Updates to [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper. ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11) - [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439). @@ -95,7 +97,7 @@ + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication. + Remove `cute::copy_vec` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment,...)`. + A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel. -- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/profiler.md#cutlass-profiler). +- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! - Optimal code generation with CUDA toolkit versions 12.6. @@ -109,12 +111,12 @@ - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. - [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. - [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). -- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). -- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/cpp/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. - A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. - A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). - A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. -- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper). - A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) - Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! @@ -124,7 +126,7 @@ - [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu) - [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) -- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and +- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/cpp/profiler.md#GEMM), and [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). - [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). - A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence: @@ -137,7 +139,7 @@ - Support for residual add (beta != 0) in convolution kernels. - A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. - A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). -- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). +- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md). - Better support for MSVC as a host compiler. - Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. - Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. @@ -145,7 +147,7 @@ ## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) - + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). + + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md). + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp). + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. @@ -157,7 +159,7 @@ - 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. + [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934). + [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564). -- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). +- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). - Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). - Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. - Fixes to greatly reduce build warnings. @@ -176,7 +178,7 @@ * Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above). * [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now. * NamedBarriers usability improvement and list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released. -* Improved [CuTe documentation](./media/docs/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved. +* Improved [CuTe documentation](./media/docs/cpp/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved. ## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31) * [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. @@ -227,7 +229,7 @@ * Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization. * Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler. * Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel. -* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. +* Changes to the [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. * [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers. * [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM. * [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel. @@ -239,10 +241,10 @@ * Updates and bugfixes from the community (thanks!) ## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23) -* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors. -* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md). -* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md). -* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3. +* [CuTe](./media/docs/cpp/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors. +* [A new conceptual operation hierarchy](./media/docs/cpp/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/cpp/gemm_api_3x.md). +* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cpp/cutlass_3x_backwards_compatibility.md). +* Updates to [Functionality](./media/docs/cpp/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3. * Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture). * New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. * Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations. @@ -420,7 +422,7 @@ * Global memory iterators supporting Fprop, Dgrad, and Wgrad * `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture * `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures - * [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation + * [Documentation](./media/docs/cpp/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation ## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) @@ -434,7 +436,7 @@ * NVIDIA Ampere GPU Architecture examples and documentation: * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and * [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu) - * Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue) + * Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/cpp/gemm_api.md#efficient-epilogue) ## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) @@ -454,7 +456,7 @@ * Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON` ## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06) - * BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library) + * BLAS-style host-side API added to [CUTLASS Library](./media/docs/cpp/quickstart.md#cutlass-library) * API to launch compiled kernel instances for GEMM and planar complex GEMM * Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores * Computes complex matrix products on matrices stored as disjoint real and imaginary parts @@ -468,10 +470,10 @@ * Encapsulated functionality embodying modern C++11 programming techniques * Optimized containers and data types for efficient, generic, portable device code * Updates to: - * [Quick start guide](./media/docs/quickstart.md) + * [Quick start guide](./media/docs/cpp/quickstart.md) * [Documentation](./README.md#documentation) - * [Utilities](./media/docs/utilities.md) - * [CUTLASS Profiler](./media/docs/profiler.md) + * [Utilities](./media/docs/cpp/utilities.md) + * [CUTLASS Profiler](./media/docs/cpp/profiler.md) * Native Turing Tensor Cores * Efficient GEMM kernels targeting Turing Tensor Cores * Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands diff --git a/README.md b/README.md index 77a81620..4593e94d 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,9 @@ the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly. +See the [Quick Start Guide](./media/docs/cpp/quickstart.md) to get started quickly. -See the [functionality docs](./media/docs/functionality.md) for a more comprehensive +See the [functionality docs](./media/docs/cpp/functionality.md) for a more comprehensive list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU architecture. @@ -54,14 +54,16 @@ architecture. * Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. - - Support for [grouped GEMM with blockwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. -* Added support for enhanced kernel performance search in CUTLASS: + - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture. +* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler: - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance. - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. - - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). +* Support `void` as the D element in sm100 kernel epilogues. Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. @@ -108,7 +110,7 @@ Layouts can also be combined and manipulated via functional composition, on whic CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design and improves code composability and readability. More documentation specific to CuTe can be found in its -[dedicated documentation directory](./media/docs/cute/00_quickstart.md). +[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md). # Compatibility @@ -155,6 +157,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be |NVIDIA H100 Tensor Core GPU |9.0|11.8| |NVIDIA H200 Tensor Core GPU |9.0|11.8| |NVIDIA B200 Tensor Core GPU |10.0|12.8| +|NVIDIA GeForce RTX 50x0 series |10.0|12.8| ## Target Architecture @@ -190,7 +193,7 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels compiled for Blackwell SM100 architecture with arch conditional features (using `sm100a`) are not compatible with RTX 50 series GPUs. -Please refer to the [functionality documentation](./media/docs/functionality.md) +Please refer to the [functionality documentation](./media/docs/cpp/functionality.md) for details on which kernels require which target architectures. # Documentation @@ -198,22 +201,22 @@ for details on which kernels require which target architectures. CUTLASS is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). -- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS -- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS -- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA -- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components -- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts -- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts -- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS -- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project -- [Terminology](./media/docs/terminology.md) - describes terms used in the code -- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++ -- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays -- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory -- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory -- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application -- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development -- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent +- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS +- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS +- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA +- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components +- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts +- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts +- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS +- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project +- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code +- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++ +- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays +- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory +- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory +- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application +- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development +- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent kernels in the same stream, and how it is used in CUTLASS. # Resources @@ -233,7 +236,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th paths. CUTLASS unit tests, examples, and utilities can be build with CMake. -The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md). +The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md). Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed on your system. @@ -278,7 +281,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl and template concepts defined in the CUTLASS project. A detailed explanation of the source code organization may be found in the -[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below. +[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below. ## CUTLASS Template Library @@ -352,7 +355,7 @@ tools/ The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate basic usage of Core API components and complete tests of the CUTLASS GEMM computations. -Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md). +Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md). # Performance Profiling @@ -568,9 +571,9 @@ reference_device: Passed ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels: - - [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples) - - [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples) -- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md) + - [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples) + - [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples) +- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md) # About diff --git a/examples/65_distributed_gemm/65_distributed_gemm.cu b/examples/65_distributed_gemm/65_distributed_gemm.cu index 2289d62a..90b6ff8b 100644 --- a/examples/65_distributed_gemm/65_distributed_gemm.cu +++ b/examples/65_distributed_gemm/65_distributed_gemm.cu @@ -834,10 +834,10 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater)." << std::endl; + << "This example requires a GPU of NVIDIA's Hopper Architecture " + << "(compute capability 90)." << std::endl; return 0; } diff --git a/examples/65_distributed_gemm/README.md b/examples/65_distributed_gemm/README.md index e3c48a9d..6bfff53c 100644 --- a/examples/65_distributed_gemm/README.md +++ b/examples/65_distributed_gemm/README.md @@ -63,6 +63,10 @@ procedure is the same, simply modify the following line in the example: using TP = _8; ``` +## References +* [Distributed GEMM Blog](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b) +* [Distributed GEMM Talk on CUDA Mode](https://www.youtube.com/watch?v=NHRTCQBZokg) + ## Copyright Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/examples/65_distributed_gemm/REQUIREMENTS.md b/examples/65_distributed_gemm/REQUIREMENTS.md index 4b8cca3b..c6288a91 100644 --- a/examples/65_distributed_gemm/REQUIREMENTS.md +++ b/examples/65_distributed_gemm/REQUIREMENTS.md @@ -17,6 +17,8 @@ Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit ar This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary CUDA graph APIs. +The minimum CUDA driver version for running this example is [560.28.03](https://docs.nvidia.com/cuda/archive/12.6.0/cuda-toolkit-release-notes/index.html#id5). + ### Hardware / driver settings This example requires Hopper GPUs with NVLink network. diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu index 3148d2aa..10cfe89d 100644 --- a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu @@ -30,11 +30,9 @@ **************************************************************************************************/ /*! \file - \brief A FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + \brief An FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. */ - - #include #include "cutlass/cutlass.h" @@ -115,7 +113,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD, - cutlass::epilogue::TmaWarpSpecialized1Sm + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -125,7 +123,7 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder ElementAccumulator, MmaTileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment + cutlass::gemm::KernelScheduleSm100Blockwise >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -222,8 +220,7 @@ struct Options { } /// Compute performance in GFLOP/s - double gflops(double runtime_s) const - { + double gflops(double runtime_s) const { // Two flops per multiply-add uint64_t flop = uint64_t(2) * m * n * k; double gflop = double(flop) / double(1.0e9); @@ -232,8 +229,7 @@ struct Options { }; /// Result structure -struct Result -{ +struct Result { double avg_runtime_ms; double gflops; cutlass::Status status; @@ -273,13 +269,16 @@ bool initialize_tensor( if (bits_input == 1) { scope_max = 2; scope_min = 0; - } else if (bits_input <= 8) { + } + else if (bits_input <= 8) { scope_max = 2; scope_min = -2; - } else if (bits_output == 16) { + } + else if (bits_output == 16) { scope_max = 5; scope_min = -5; - } else { + } + else { scope_max = 8; scope_min = -8; } @@ -392,8 +391,7 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) -{ +typename Gemm::Arguments args_from_options(const Options &options) { typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k, options.l}, @@ -468,8 +466,7 @@ bool verify(const Options &options) { /// Execute a given example GEMM computation template -int run(Options &options) -{ +int run(Options &options) { initialize(options); @@ -510,8 +507,7 @@ int run(Options &options) } // Run profiling loop - if (options.iterations > 0) - { + if (options.iterations > 0) { GpuTimer timer; timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu index 11083e09..6d8d1de0 100644 --- a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief A FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + \brief An FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. */ #include @@ -96,9 +96,9 @@ using ElementCompute = float; // MMA and Cluster Tile Shapes // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 -using MmaTileShape_MNK = Shape<_128,_128,_128>; +using MmaTileShape_MNK = Shape<_256,_128,_128>; // Shape of the threadblocks in a cluster -using ClusterShape_MNK = Shape<_1,_1,_1>; +using ClusterShape_MNK = Shape<_2,_1,_1>; constexpr int ScaleGranularityM = 1; constexpr int ScaleGranularityN = 128; @@ -120,7 +120,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD, - cutlass::epilogue::TmaWarpSpecialized1Sm + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -130,7 +130,7 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder ElementAccumulator, MmaTileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment + cutlass::gemm::KernelScheduleSm100Blockwise >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -227,8 +227,7 @@ struct Options { } /// Compute performance in GFLOP/s - double gflops(double runtime_s) const - { + double gflops(double runtime_s) const { // Two flops per multiply-add uint64_t flop = uint64_t(2) * m * n * k; double gflop = double(flop) / double(1.0e9); @@ -237,8 +236,7 @@ struct Options { }; /// Result structure -struct Result -{ +struct Result { double avg_runtime_ms; double gflops; cutlass::Status status; @@ -278,13 +276,16 @@ bool initialize_tensor( if (bits_input == 1) { scope_max = 2; scope_min = 0; - } else if (bits_input <= 8) { + } + else if (bits_input <= 8) { scope_max = 2; scope_min = -2; - } else if (bits_output == 16) { + } + else if (bits_output == 16) { scope_max = 5; scope_min = -5; - } else { + } + else { scope_max = 8; scope_min = -8; } @@ -397,9 +398,8 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) -{ - typename Gemm::Arguments arguments{ +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments { cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k, options.l}, {tensor_A.device_data(), stride_A, @@ -473,8 +473,7 @@ bool verify(const Options &options) { /// Execute a given example GEMM computation template -int run(Options &options) -{ +int run(Options &options) { initialize(options); @@ -515,8 +514,7 @@ int run(Options &options) } // Run profiling loop - if (options.iterations > 0) - { + if (options.iterations > 0) { GpuTimer timer; timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu new file mode 100644 index 00000000..b43869e7 --- /dev/null +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu @@ -0,0 +1,754 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief An FP8 blockwise-scaled grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + In this example M, N, and K are fixed across groups. +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; +using ElementCompute = float; + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_128,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_1,_1,_1>; +// Shape of the tile computed by each SM + +using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{})); + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutC *, AlignmentD, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100 + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +static_assert(cute::is_same_v); +static_assert(cute::is_same_v); + +/// Initialization +uint64_t seed; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_SFA; +std::vector offset_SFB; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; + +std::vector ptr_ref_D_host; + +std::vector ptr_A_host; +std::vector ptr_B_host; +std::vector ptr_C_host; +std::vector ptr_D_host; +std::vector ptr_SFA_host; +std::vector ptr_SFB_host; + +// Shared Allocations + +cutlass::HostTensor block_A; +cutlass::HostTensor block_B; +cutlass::HostTensor block_C; +cutlass::HostTensor block_D; +cutlass::HostTensor block_ref_D; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_SFB; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 2048, k = 512, groups = 10; + std::vector problem_sizes_host; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + for (int i = 0; i < groups; ++i) { + problem_sizes_host.push_back({m, n, k}); + } + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "81_blackwell_grouped_gemm_blockwise\n\n" + << " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "81_blackwell_grouped_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * groups; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Helper to initialize a block of device data (scale_tensors) +template +bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + + scope_min = -1; + scope_max = 1; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options const& options) { + int32_t total_elements_A = 0; + int32_t total_elements_B = 0; + int32_t total_elements_C = 0; + int32_t total_elements_D = 0; + int32_t total_elements_SFA = 0; + int32_t total_elements_SFB = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_SFA.push_back(total_elements_SFA); + offset_SFB.push_back(total_elements_SFB); + + int32_t elements_A = M * K; + int32_t elements_B = K * N; + int32_t elements_C = M * N; + int32_t elements_D = M * N; + + auto gemm_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto gemm_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + int32_t elements_SFA = cosize(gemm_layout_SFA); + int32_t elements_SFB = cosize(gemm_layout_SFB); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_SFA += elements_SFA; + total_elements_SFB += elements_SFB; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(gemm_layout_SFA); + layout_SFB_host.push_back(gemm_layout_SFB); + } + + block_A.resize(cutlass::make_Coord(total_elements_A)); + block_B.resize(cutlass::make_Coord(total_elements_B)); + block_C.resize(cutlass::make_Coord(total_elements_C)); + block_D.resize(cutlass::make_Coord(total_elements_D)); + block_ref_D.resize(cutlass::make_Coord(total_elements_D)); + block_SFA.resize(cutlass::make_Coord(total_elements_SFA)); + block_SFB.resize(cutlass::make_Coord(total_elements_SFB)); + + initialize_tensor(block_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(block_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(block_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + initialize_scale_tensor(block_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2026); + initialize_scale_tensor(block_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2027); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + + // copy problem sizes + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + std::vector device_ptr_A_host(options.groups); + std::vector device_ptr_B_host(options.groups); + std::vector device_ptr_C_host(options.groups); + std::vector device_ptr_D_host(options.groups); + std::vector device_ptr_SFA_host(options.groups); + std::vector device_ptr_SFB_host(options.groups); + + ptr_A_host = std::vector(options.groups); + ptr_B_host = std::vector(options.groups); + ptr_C_host = std::vector(options.groups); + ptr_D_host = std::vector(options.groups); + ptr_SFA_host = std::vector(options.groups); + ptr_SFB_host = std::vector(options.groups); + ptr_ref_D_host = std::vector(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + // Ptrs for A + ptr_A_host.at(i) = block_A.host_data() + offset_A.at(i); + device_ptr_A_host.at(i) = block_A.device_data() + offset_A.at(i); + + // Ptrs for B + ptr_B_host.at(i) = block_B.host_data() + offset_B.at(i); + device_ptr_B_host.at(i) = block_B.device_data() + offset_B.at(i); + + // Ptrs for C + ptr_C_host.at(i) = block_C.host_data() + offset_C.at(i); + device_ptr_C_host.at(i) = block_C.device_data() + offset_C.at(i); + + // Ptrs for D + ptr_D_host.at(i) = block_D.host_data() + offset_D.at(i); + device_ptr_D_host.at(i) = block_D.device_data() + offset_D.at(i); + ptr_ref_D_host.at(i) = block_ref_D.host_data() + offset_D.at(i); + + // Ptrs for SFA + ptr_SFA_host.at(i) = block_SFA.host_data() + offset_SFA.at(i); + device_ptr_SFA_host.at(i) = block_SFA.device_data() + offset_SFA.at(i); + + // Ptrs for SFB + ptr_SFB_host.at(i) = block_SFB.host_data() + offset_SFB.at(i); + device_ptr_SFB_host.at(i) = block_SFB.device_data() + offset_SFB.at(i); + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(device_ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(device_ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(device_ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(device_ptr_D_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(device_ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(device_ptr_SFB_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), + ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), + ptr_SFB.get(), layout_SFB.get() + }, + { + {}, // epilogue.thread + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + hw_info + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + block_D.sync_host(); + + for (int i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(ptr_A_host.at(i), + cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i))); + auto B = cute::make_tensor(ptr_B_host.at(i), + cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i))); + auto C = cute::make_tensor(ptr_C_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i))); + auto D = cute::make_tensor(ptr_ref_D_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i))); + + auto SFA = cute::make_tensor(ptr_SFA_host.at(i), layout_SFA_host.at(i)); + auto SFB = cute::make_tensor(ptr_SFB_host.at(i), layout_SFB_host.at(i)); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + } + + bool passed = cutlass::reference::host::TensorEquals(block_ref_D.host_view(), block_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.groups << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least sm100a. + + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu new file mode 100644 index 00000000..60667cda --- /dev/null +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu @@ -0,0 +1,761 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief An FP8 blockwise-scaled grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + In this example M, N, and K are fixed across groups. +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; +using ElementCompute = float; + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; +// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 + +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; +using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig; + +// Note when we have multiple scale factors per tile (in this case 128 scales in M per tile), we will restrict up to a +// 16B alignment if possible (i.e., we have at least 16B of scales in M). +// In this case the smallest M that can be executed is 16. To avoid this for smaller M, you can swap A and B +// and transpose A, B, C, and scales since B^T A^T = C^T. +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutC *, AlignmentD, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100 + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +static_assert(cute::is_same_v); +static_assert(cute::is_same_v); + +/// Initialization +uint64_t seed; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_SFA; +std::vector offset_SFB; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; + +std::vector ptr_ref_D_host; + +std::vector ptr_A_host; +std::vector ptr_B_host; +std::vector ptr_C_host; +std::vector ptr_D_host; +std::vector ptr_SFA_host; +std::vector ptr_SFB_host; + +// Shared Allocations + +cutlass::HostTensor block_A; +cutlass::HostTensor block_B; +cutlass::HostTensor block_C; +cutlass::HostTensor block_D; +cutlass::HostTensor block_ref_D; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_SFB; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 2048, k = 512, groups = 10; + std::vector problem_sizes_host; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + for (int i = 0; i < groups; ++i) { + problem_sizes_host.push_back({m, n, k}); + } + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "81_blackwell_grouped_gemm_groupwise\n\n" + << " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "81_blackwell_grouped_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * groups; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Helper to initialize a block of device data (scale_tensors) +template +bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + + scope_min = -1; + scope_max = 1; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + int32_t total_elements_A = 0; + int32_t total_elements_B = 0; + int32_t total_elements_C = 0; + int32_t total_elements_D = 0; + int32_t total_elements_SFA = 0; + int32_t total_elements_SFB = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_SFA.push_back(total_elements_SFA); + offset_SFB.push_back(total_elements_SFB); + + int32_t elements_A = M * K; + int32_t elements_B = K * N; + int32_t elements_C = M * N; + int32_t elements_D = M * N; + + auto gemm_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto gemm_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + int32_t elements_SFA = cosize(gemm_layout_SFA); + int32_t elements_SFB = cosize(gemm_layout_SFB); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_SFA += elements_SFA; + total_elements_SFB += elements_SFB; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(gemm_layout_SFA); + layout_SFB_host.push_back(gemm_layout_SFB); + } + + block_A.resize(cutlass::make_Coord(total_elements_A)); + block_B.resize(cutlass::make_Coord(total_elements_B)); + block_C.resize(cutlass::make_Coord(total_elements_C)); + block_D.resize(cutlass::make_Coord(total_elements_D)); + block_ref_D.resize(cutlass::make_Coord(total_elements_D)); + block_SFA.resize(cutlass::make_Coord(total_elements_SFA)); + block_SFB.resize(cutlass::make_Coord(total_elements_SFB)); + + initialize_tensor(block_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(block_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(block_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + initialize_scale_tensor(block_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2026); + initialize_scale_tensor(block_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2027); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + + // copy problem sizes + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + std::vector device_ptr_A_host(options.groups); + std::vector device_ptr_B_host(options.groups); + std::vector device_ptr_C_host(options.groups); + std::vector device_ptr_D_host(options.groups); + std::vector device_ptr_SFA_host(options.groups); + std::vector device_ptr_SFB_host(options.groups); + + ptr_A_host = std::vector(options.groups); + ptr_B_host = std::vector(options.groups); + ptr_C_host = std::vector(options.groups); + ptr_D_host = std::vector(options.groups); + ptr_SFA_host = std::vector(options.groups); + ptr_SFB_host = std::vector(options.groups); + ptr_ref_D_host = std::vector(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + // Ptrs for A + ptr_A_host.at(i) = block_A.host_data() + offset_A.at(i); + device_ptr_A_host.at(i) = block_A.device_data() + offset_A.at(i); + + // Ptrs for B + ptr_B_host.at(i) = block_B.host_data() + offset_B.at(i); + device_ptr_B_host.at(i) = block_B.device_data() + offset_B.at(i); + + // Ptrs for C + ptr_C_host.at(i) = block_C.host_data() + offset_C.at(i); + device_ptr_C_host.at(i) = block_C.device_data() + offset_C.at(i); + + // Ptrs for D + ptr_D_host.at(i) = block_D.host_data() + offset_D.at(i); + device_ptr_D_host.at(i) = block_D.device_data() + offset_D.at(i); + ptr_ref_D_host.at(i) = block_ref_D.host_data() + offset_D.at(i); + + // Ptrs for SFA + ptr_SFA_host.at(i) = block_SFA.host_data() + offset_SFA.at(i); + device_ptr_SFA_host.at(i) = block_SFA.device_data() + offset_SFA.at(i); + + // Ptrs for SFB + ptr_SFB_host.at(i) = block_SFB.host_data() + offset_SFB.at(i); + device_ptr_SFB_host.at(i) = block_SFB.device_data() + offset_SFB.at(i); + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(device_ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(device_ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(device_ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(device_ptr_D_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(device_ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(device_ptr_SFB_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), + ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), + ptr_SFB.get(), layout_SFB.get() + }, + { + {}, // epilogue.thread + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + hw_info + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + block_D.sync_host(); + + for (int i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(ptr_A_host.at(i), + cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i))); + auto B = cute::make_tensor(ptr_B_host.at(i), + cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i))); + auto C = cute::make_tensor(ptr_C_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i))); + auto D = cute::make_tensor(ptr_ref_D_host.at(i), + cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i))); + + auto SFA = cute::make_tensor(ptr_SFA_host.at(i), layout_SFA_host.at(i)); + auto SFB = cute::make_tensor(ptr_SFB_host.at(i), layout_SFB_host.at(i)); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + } + + bool passed = cutlass::reference::host::TensorEquals(block_ref_D.host_view(), block_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.groups << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least sm100a. + + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/81_blackwell_gemm_blockwise/CMakeLists.txt b/examples/81_blackwell_gemm_blockwise/CMakeLists.txt index a4dc34d0..8b981546 100644 --- a/examples/81_blackwell_gemm_blockwise/CMakeLists.txt +++ b/examples/81_blackwell_gemm_blockwise/CMakeLists.txt @@ -54,4 +54,18 @@ cutlass_example_add_executable( TEST_SMALL ) +cutlass_example_add_executable( + 81_blackwell_grouped_gemm_blockwise + 81_blackwell_grouped_gemm_blockwise.cu + TEST_COMMAND_OPTIONS + TEST_SMALL +) + +cutlass_example_add_executable( + 81_blackwell_grouped_gemm_groupwise + 81_blackwell_grouped_gemm_groupwise.cu + TEST_COMMAND_OPTIONS + TEST_SMALL +) + endif() diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index b1490b02..f0a99359 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -98,19 +98,23 @@ epilogue_predication(ThrMMA const& thr_mma, } } -template + class SmemCopyLdOpC, class SmemCopyStOpC> CUTE_HOST_DEVICE void -epilogue_no_predication(Alpha const& alpha, +epilogue_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Alpha const& alpha, Tensor & tCrC, Beta const& beta, - Tensor & tCsC, + Tensor & sC, CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C - SmemCopyOpC const& sC_copy_op) + SmemCopyLdOpC const& sC_copy_ld_op, + SmemCopyStOpC const& sC_copy_st_op) { using InputTypeC = typename TSC::value_type; using ComputeTypeC = typename TRC::value_type; @@ -125,10 +129,18 @@ epilogue_no_predication(Alpha const& alpha, CUTE_GCC_UNREACHABLE; } (); - Tensor tCrDi = make_fragment_like(tCsC); Tensor tCrD = make_fragment_like(tCrC); + Tensor tCrDi = make_fragment_like(tCrD); + if(!isBetaZero) { - copy(sC_copy_op, tCsC, tCrDi); + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_S(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCsC, tCrDi_copy_view); + // Transform C on/after load cute::transform(tCrDi, tCrD, sC_load_op); } @@ -136,7 +148,14 @@ epilogue_no_predication(Alpha const& alpha, axpby(alpha, tCrC, beta, tCrD); // Transform C before/on store cute::transform(tCrD, tCrDi, sC_store_op); - copy(sC_copy_op, tCrDi, tCsC); + + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_D(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCrDi_copy_view, tCsC); } // Predicated Cooperative GEMM @@ -283,7 +302,9 @@ cooperative_gemm_no_predication(uint32_t thread_idx, // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrAi = make_fragment_like(tCrA); Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCrBi = make_fragment_like(tCrB); using CopyOpAType = SmemCopyOpA; using CopyOpBType = SmemCopyOpB; @@ -291,7 +312,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx, auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); Tensor tCsA = smem_thr_copy_A.partition_S(sA); - Tensor tCrAi = make_fragment_like(tCsA); Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K @@ -299,7 +319,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx, auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); Tensor tCsB = smem_thr_copy_B.partition_S(sB); - Tensor tCrBi = make_fragment_like(tCsB); Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K @@ -346,7 +365,7 @@ template + class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy> CUTE_HOST_DEVICE void cooperative_gemm(uint32_t thread_idx, @@ -356,13 +375,14 @@ cooperative_gemm(uint32_t thread_idx, Tensor const& sB, Beta const& beta, Tensor & sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C - SmemCopyOpA const& sA_copy_op = {}, - SmemCopyOpB const& sB_copy_op = {}, - SmemCopyOpC const& sC_copy_op = {}) + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) { CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); @@ -394,7 +414,7 @@ cooperative_gemm(uint32_t thread_idx, thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op ); detail::epilogue_no_predication( - alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op + thread_idx, thr_mma,alpha, tCrC, beta, sC, sC_load_op, sC_store_op, sC_copy_ld_op, sC_copy_st_op ); } else { detail::cooperative_gemm_predication( @@ -466,7 +486,7 @@ template + class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy> CUTE_HOST_DEVICE void cooperative_gemm(uint32_t thread_idx, @@ -476,17 +496,18 @@ cooperative_gemm(uint32_t thread_idx, Tensor const& sB, Beta const& beta, Tensor && sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C - SmemCopyOpA const& sA_copy_op = {}, - SmemCopyOpB const& sB_copy_op = {}, - SmemCopyOpC const& sC_copy_op = {}) + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) { cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op, - sA_copy_op, sB_copy_op, sC_copy_op); + sA_copy_op, sB_copy_op, sC_copy_ld_op, sC_copy_st_op); } // Legacy overload of cute::gemm for backwards-compatibility diff --git a/include/cute/arch/mma_sm120.hpp b/include/cute/arch/mma_sm120.hpp index 84c09b8b..1433a2c8 100644 --- a/include/cute/arch/mma_sm120.hpp +++ b/include/cute/arch/mma_sm120.hpp @@ -3245,7 +3245,7 @@ rr_blockscaled_op_selector_sm120() { if constexpr (UseF8F6F4) { return SM120::BLOCKSCALED::SM120_16x8x32_TN_VS{}; - } + } else{ return SM120::BLOCKSCALED::SM120_16x8x64_TN_VS{}; } diff --git a/include/cute/arch/tmem_allocator_sm100.hpp b/include/cute/arch/tmem_allocator_sm100.hpp index 2d2cac9d..9839e740 100644 --- a/include/cute/arch/tmem_allocator_sm100.hpp +++ b/include/cute/arch/tmem_allocator_sm100.hpp @@ -57,7 +57,7 @@ public: * @pre Must never be issued by more than one warp at the same time. * @pre For repeated allocations, the same warp must be used to issue all allocations. **/ - __device__ void + CUTE_HOST_DEVICE void allocate(int num_columns, uint32_t* dst_ptr) { #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); @@ -116,7 +116,7 @@ public: * @pre For repeated allocations, the same warp must be used to issue all allocations. * @pre The 2 warps from participating CTAs have the same logical warp ID. **/ - __device__ void + CUTE_HOST_DEVICE void allocate(int num_columns, uint32_t* dst_ptr) { #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index ac196028..6a3883e8 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -88,7 +88,7 @@ namespace cute { /// CUTE helper to cast SMEM pointer to unsigned -CUTE_DEVICE +CUTE_HOST_DEVICE uint32_t cast_smem_ptr_to_uint(void const* const ptr) { diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index 83dcd4e6..147458b8 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -57,7 +57,7 @@ template >(t) < static_cast>(u) ? t : u; } template & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// /// CUTLASS helper to get SMEM pointer -CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { +CUTLASS_HOST_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { return cute::cast_smem_ptr_to_uint(ptr); } diff --git a/include/cutlass/arch/wmma.h b/include/cutlass/arch/wmma.h index 2cafa510..9cb9c04f 100644 --- a/include/cutlass/arch/wmma.h +++ b/include/cutlass/arch/wmma.h @@ -34,9 +34,6 @@ #pragma once -// CUTLASS WMMA does not support clang at present. -#if !(defined(__clang__) && defined(__CUDA__)) - #if (__CUDACC_VER_MAJOR__ >= 9) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) #define CUTLASS_ARCH_WMMA_ENABLED @@ -58,8 +55,6 @@ #endif #endif -#endif //!(defined(__clang__) && defined(__CUDA__)) - #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include diff --git a/include/cutlass/array.h b/include/cutlass/array.h index e1e18282..ce33110a 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -986,6 +986,21 @@ struct multiply_add, Array, Array> { return result; } + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, T const &scalar) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE Array operator()(Array const &a, T const &scalar_b, T const &scalar_c) const { diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index 16eb4fc9..176b1f25 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -866,6 +866,45 @@ struct CallbacksBuilder< >; }; +// ptr array aux fusion callbacks builder for sm100 tma epilogue +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class CtaTileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp +> +struct CallbacksBuilder< + Sm100PtrArrayTmaWarpSpecialized, + FusionOp, + CtaTileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm100_get_smem_store_op< + GmemStrideTypeAux, typename FusionOp::ElementAux, ElementAccumulator, AccLoadOp>()); + using CopyOpS2R = decltype(detail::sm100_get_smem_load_op< + GmemStrideTypeAux, typename FusionOp::ElementAux, ElementAccumulator, AccLoadOp>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm100PtrArrayTmaWarpSpecialized, + FusionOp, CtaTileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + template < int StagesC, int StagesD, @@ -930,7 +969,7 @@ template < class ElementC_, class GmemLayoutTagC_, int AlignmentC, - class ElementD, + class ElementD_, class GmemLayoutTagD, int AlignmentD, class Schedule, @@ -943,6 +982,9 @@ private: static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); + static constexpr bool DisableDestination = cute::is_void_v; + using ElementD = cute::conditional_t,ElementD_>; // prevents void ref breakages + // Passing void C disables source load + smem allocation static constexpr bool DisableSource = cute::is_void_v; using ElementC = cute::conditional_t; // prevents void ref breakages @@ -1168,7 +1210,7 @@ public: EpilogueTile_MN, ElementC_, // Need to pass void through to expose via GemmUniversal GmemStrideTypeC, - ElementD, + ElementD_, // Need to pass void through to expose via GemmUniversal GmemStrideTypeD, decltype(fusion_callbacks()), AccLoadOp, diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index f6844375..50a5420b 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -206,6 +206,46 @@ struct CallbacksBuilder< >; }; +// ptr array aux fusion callbacks builder for sm90 tma epilogue +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90PtrArrayTmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> // aux subbyte tensor doesn't use smem +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm90PtrArrayTmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + template < int StagesC, int StagesD, diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp index c142a0b7..77a3b510 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -129,8 +129,13 @@ public: static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); private: - using GmemElementD = ElementD; - using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using GmemElementD = cute::conditional_t>; + using GmemElementC = cute::conditional_t; // prevents void ref breakages + static_assert(not cute::is_void_v, "GmemElementD is void"); + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; constexpr static int StagesC = StagesC_; @@ -138,9 +143,8 @@ private: static_assert(StagesC >= 1, "StagesC must be >= 1"); static_assert(StagesD >= 1, "StagesD must be >= 1"); - constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool ReuseSmemC = ReuseSmemC_ && is_destination_supported; constexpr static bool DelayTmaStore = DelayTmaStore_; - constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); @@ -159,7 +163,7 @@ private: using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); - constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC && MaxStageBits % sizeof_bits_v == 0 && MaxStageBits % sizeof_bits_v == 0; static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); @@ -239,7 +243,7 @@ public: using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, make_tensor( - make_gmem_ptr(static_cast,ElementD,ElementC> const*>(nullptr)), + make_gmem_ptr(static_cast(nullptr)), TensorShapeC{}, append<3>(InternalStrideC{}, _0{})), SmemLayoutStageC{}, @@ -248,7 +252,7 @@ public: using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, make_tensor( - make_gmem_ptr(static_cast(nullptr)), + make_gmem_ptr(static_cast(nullptr)), TensorShapeD{}, append<3>(InternalStrideD{}, _0{})), SmemLayoutStageD{}, @@ -278,6 +282,8 @@ public: // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. // These will be replaced with correct values before the initial tma load. auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. constexpr int tma_alignment_bits = 128; auto init_M = tma_alignment_bits; auto init_N = tma_alignment_bits; @@ -308,10 +314,13 @@ public: tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, EpilogueTile{}, _1{}); } - // Tensor pointers will be fixed before the first access - ElementD* ptr_D_first_batch = nullptr; - Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); - typename Params::TMA_D tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, EpilogueTile{}, _1{}); + typename Params::TMA_D tma_store_d{}; + if constexpr (is_destination_supported) { + // Tensor pointers will be fixed before the first access + ElementD* ptr_D_first_batch = nullptr; + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, EpilogueTile{}, _1{}); + } auto fusion_workspace = static_cast(workspace); auto fusion_workspace_size = round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment); @@ -359,9 +368,11 @@ public: auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); auto [M,N,K,L] = problem_shape_MNKL; - constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } if constexpr (is_source_supported) { constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); @@ -752,13 +763,9 @@ public: thread_idx }; - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; // Predication for sub-128 thread T2R tiled copy Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; @@ -795,31 +802,38 @@ public: [[maybe_unused]] int epi_n_prev = 0; static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); - auto epi_loop_fn = [&] (auto& cst_callbacks) { - // The TMA store sequence for one subtile iteration - auto tma_store_fn = [&] (int epi_m, int epi_n) { + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // The TMA store sequence for one epilogue loop iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { // Write the tile from smem to gmem with TMA cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d.with(get<0>(store_tensormap_info)), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); - } + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d.with(get<0>(store_tensormap_info)), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + // Post async fence, pre TMA commit callback entry point cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - + // Commit the TMA stores for this stage if (issue_tma_store) { store_pipeline.producer_commit(store_pipe_producer_state); } ++store_pipe_producer_state; - + // Wait for the next smem buffer to be available if (issue_tma_store) { store_pipeline.producer_acquire(store_pipe_producer_state); } synchronize(); - + if constexpr (ReuseSmemC) { // producer_acquire returns when at most StagesD-1 committed stores are pending bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; @@ -831,11 +845,7 @@ public: ++load_pipe_consumer_state; } } - }; - - // - // BEGIN EPILOGUE - // + }; // tma_store_fn // Begin the wait for the producer load results ConsumerToken load_wait_token{BarrierStatus::WaitDone}; @@ -953,8 +963,10 @@ public: // Copy output tile from register to smem bool issue_smem_store = issue_tmem_load; - if (issue_smem_store) { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + if constexpr (is_destination_supported) { + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } } // Post reduction, pre TMA store callback entry point @@ -982,9 +994,11 @@ public: cst_callbacks.end(); }; - epi_loop_fn(cst_callbacks); - cst_callbacks.end(); - + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); } @@ -1343,7 +1357,7 @@ public: } __syncwarp(); } - } else { + } else if constexpr (is_destination_supported) { int const offset_Ddesc = cute::is_void_v ? 0 : sm_count; tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; if (cute::elect_one_sync()) { @@ -1374,7 +1388,7 @@ public: params.ptr_C[next_batch]); } } - } else { + } else if constexpr (is_destination_supported) { cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, params.ptr_D[next_batch]); } @@ -1414,7 +1428,7 @@ public: } } } - else { + else if constexpr (is_destination_supported) { ElementD const* ptr_D = nullptr; Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); @@ -1473,7 +1487,7 @@ public: } tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); } - } else { + } else if constexpr (is_destination_supported) { tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); } } @@ -1486,7 +1500,7 @@ public: if (is_source_supported) { cute::tma_descriptor_fence_acquire(tensormap); } - } else { + } else if constexpr (is_destination_supported) { cute::tma_descriptor_fence_acquire(tensormap); } } diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp index 778f8769..2eb5c582 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -646,12 +646,12 @@ public: thread_idx }; - auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); - bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); - auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + // The Epilogue Loop auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); + // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. synchronize(); cst_callbacks.begin(); @@ -747,6 +747,10 @@ public: cst_callbacks.end(); }; + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); epi_loop_fn(cst_callbacks); return cute::make_tuple(acc_pipe_consumer_state); } diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp index a144accd..89e5448c 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp @@ -687,7 +687,7 @@ public: // OOB predication for tile quantization "residue" // Absolute coordinate tensors (dynamic) Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) // Relative coordinate tensors (static) Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) @@ -696,7 +696,7 @@ public: auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) - // Get the fusion callbacks for the consumer store warps + // Arguments for the fusion callbacks for the consumer store warps constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ problem_shape_mnkl, @@ -713,10 +713,6 @@ public: thread_idx }; - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; @@ -756,8 +752,12 @@ public: [[maybe_unused]] int epi_n_prev = 0; static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + // The Epilogue Loop auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - // The TMA store sequence for one subtile iteration + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // The TMA store sequence for one epilogue loop iteration auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { // Write the tile from smem to gmem with TMA cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA @@ -765,22 +765,22 @@ public: if (issue_tma_store) { copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); } - + // Post async fence, pre TMA commit callback entry point cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - + // Commit the TMA stores for this stage if (issue_tma_store) { store_pipeline.producer_commit(store_pipe_producer_state); } ++store_pipe_producer_state; - + // Wait for the next smem buffer to be available if (issue_tma_store) { store_pipeline.producer_acquire(store_pipe_producer_state); } synchronize(); - + if constexpr (ReuseSmemC) { // producer_acquire returns when at most StagesD-1 committed stores are pending bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; @@ -792,11 +792,8 @@ public: ++load_pipe_consumer_state; } } - }; + }; // tma_store_fn - // - // BEGIN EPILOGUE - // cst_callbacks.begin(); if (cst_callbacks.begin_sync_needed()) { synchronize(); @@ -941,9 +938,13 @@ public: } cst_callbacks.end(); - }; + }; // epi_loop_fn - epi_loop_fn(cst_callbacks); + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); } diff --git a/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp index da8dca23..d026b15c 100644 --- a/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp @@ -78,12 +78,13 @@ namespace detail { } }(); + // norm_constant and qpvscale_rcps are all positive numbers. + auto acc_scales = cutlass::multiplies>{}(norm_constant, qpvscale_rcps); + CUTLASS_PRAGMA_UNROLL for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { - // norm_constant and qpvscale_rcps[sf_v] are all positive numbers. - ElementCompute acc_scale = mul(norm_constant, qpvscale_rcps[sf_v]); // Map INF to fp32::max - acc_scale = minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + auto acc_scale = minimum_with_nan_propagation{}(acc_scales[sf_v], cutlass::platform::numeric_limits::max()); // Convert to output type output_frgs[sf_v] = cutlass::NumericArrayConverter{}(mul_array(compute_frgs[sf_v], acc_scale)); } @@ -240,17 +241,19 @@ struct Sm100BlockScaleFactorRowStore { cutlass::multiplies mul; cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; cutlass::Array pvscales; // SF generation CUTLASS_PRAGMA_UNROLL for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { compute_frgs[sf_v] = NumericArrayConverter{}(input_frgs[sf_v]); /// Step1: get max across a vector - ElementCompute vec_max = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); - /// Step2: Compute Scale - pvscales[sf_v] = mul(vec_max, norm_constant_scaled_down); + vec_maxs[sf_v] = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); } + /// Step2: Compute Scale + pvscales = cutlass::multiplies>{}(vec_maxs, norm_constant_scaled_down); + tC_rSFD_frg(_0{}) = cutlass::NumericArrayConverter{}(pvscales); Tensor tCgSFD_flt = filter_zeros(tC_gSFD(_,_,_,_0{},_0{},get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n)); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index c498a382..cd470f84 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -1191,9 +1191,11 @@ struct Sm90RowBroadcast { auto layout_M = make_layout(M, repeat_like(M, _0{})); auto layout_L = make_layout(L, get<2>(params.dRow)); - ElementInput const* ptr_row; + ElementInput const* ptr_row = nullptr; if constexpr(IsArrayOfPointers) { - ptr_row = params.ptr_row[l]; + if (!(EnableNullptr && params.ptr_row == nullptr)) { + ptr_row = params.ptr_row[l]; + } } else { ptr_row = params.ptr_row; } @@ -1439,9 +1441,11 @@ struct Sm90ColBroadcast { auto layout_N = make_layout(N, repeat_like(N, _0{})); auto layout_L = make_layout(L, get<2>(params.dCol)); - ElementInput const* ptr_col; + ElementInput const* ptr_col = nullptr; if constexpr(IsArrayOfPointers) { - ptr_col = params.ptr_col[l]; + if (!(EnableNullptr && params.ptr_col == nullptr)) { + ptr_col = params.ptr_col[l]; + } } else { ptr_col = params.ptr_col; } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index ce841bf2..93720f8d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -116,6 +116,172 @@ sm90_partition_for_epilogue( // ///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Producer load callbacks, called by the epilogue load warp. +// Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation +// Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but +// are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead +// If this is non-empty, is_producer_load_needed must be true. +// +template +struct ProducerLoadCallbacksImpl { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of the subtile load loop + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin(); + } + ); + } + + // Entry of the subtile load loop. Aux loads usually performed here + // Upon entry the producer acquire of the current subtile lock has completed. + // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); + } + ); + } + + // Exit of the subtile load loop. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end(); + } + ); + } +}; + + +// +// Consumer store callbacks, called by the epilogue store warps. +// All operations must redefine this, with optional inheritance from this empty implementation. +// +template +struct ConsumerStoreCallbacksImpl { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of subtile store loop. Gmem broadcasts usually performed here. + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin(); + } + ); + } + + // Is a thread sync needed after begin(). Allows chaining async copies across multiple nodes + CUTLASS_DEVICE bool + begin_sync_needed() const { + return cute::apply(callbacks_tuple, + [] (auto const&... callbacks) { + return (false || ... || callbacks.begin_sync_needed()); + } + ); + } + + // Start of subtile store iteration + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin_loop(epi_m, epi_n); + } + ); + } + + // Before visit callback. Smem broadcasts usually performed here. + // Upon entry, all producer loads for this subtile are completed and visible. + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); + } + ); + } + + // Perform the fused elementwise computation + template + CUTLASS_DEVICE auto // returns an Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) // depends on the N-naryness of the op + = delete; // Must be implemented for each operation + + // After visit call. Smem reductions usually performed here + // reduction_buffer is an arbitrary smem tensor that can be used for workspace + // It is each nodes reponsibility to assert that this buffer is sufficiently sized + // and to ensure that this buffer is no longer needed upon callback exit + // i.e. results are synchronized and no longer in the reduction buffer + // + // visit_results is a rmem tensor that contains the results of visit() for an entire + // on the current epilogue subtile + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); + } + ); + } + + // After reduce call, before smem async fence. Smem stores usually performed here. + // Upon exit, all smem stores for TMA must have been issued + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); + } + ); + } + + // After smem async fence, before TMA store commit. Aux stores usually performed here + // Upon exit, all TMA stores for this subtile must have been issued + // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores + // other gmem stores can be placed in the reduce or postreduce entry points + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); + } + ); + } + + // End of subtile store iteration + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end_loop(epi_m, epi_n); + } + ); + } + + // Exit of subtile store loop. Gmem reductions usually performed here. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end(); + } + ); + } +}; + template< class ProblemShapeMNKL, class TileShapeMNK, @@ -349,51 +515,6 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { ); } - // - // Producer load callbacks, called by the epilogue load warp. - // Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation - // Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but - // are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead - // If this is non-empty, is_producer_load_needed must be true. - // - template - struct ProducerLoadCallbacks { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of the subtile load loop - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Entry of the subtile load loop. Aux loads usually performed here - // Upon entry the producer acquire of the current subtile lock has completed. - // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); - } - ); - } - - // Exit of the subtile load loop. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } - }; - // Producer load callbacks factory // All operations must redefine this, but most can just dispatch to the base impl template @@ -405,131 +526,11 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { }, [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { auto callbacks_tuple = cute::make_tuple(callbacks...); - return ProducerLoadCallbacks{callbacks_tuple}; + return ProducerLoadCallbacksImpl{callbacks_tuple}; } ); } - // - // Consumer store callbacks, called by the epilogue store warps. - // All operations must redefine this, with optional inheritance from this empty implementation. - // - template - struct ConsumerStoreCallbacks { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of subtile store loop. Gmem broadcasts usually performed here. - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Is a thread sync needed after begin(). Allows chaining async copies across multiple nodes - CUTLASS_DEVICE bool - begin_sync_needed() const { - return cute::apply(callbacks_tuple, - [] (auto const&... callbacks) { - return (false || ... || callbacks.begin_sync_needed()); - } - ); - } - - // Start of subtile store iteration - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin_loop(epi_m, epi_n); - } - ); - } - - // Before visit callback. Smem broadcasts usually performed here. - // Upon entry, all producer loads for this subtile are completed and visible. - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); - } - ); - } - - // Perform the fused elementwise computation - template - CUTLASS_DEVICE auto // returns an Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) // depends on the N-naryness of the op - = delete; // Must be implemented for each operation - - // After visit call. Smem reductions usually performed here - // reduction_buffer is an arbitrary smem tensor that can be used for workspace - // It is each nodes reponsibility to assert that this buffer is sufficiently sized - // and to ensure that this buffer is no longer needed upon callback exit - // i.e. results are synchronized and no longer in the reduction buffer - // - // visit_results is a rmem tensor that contains the results of visit() for an entire - // on the current epilogue subtile - template - CUTLASS_DEVICE void - reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); - } - ); - } - - // After reduce call, before smem async fence. Smem stores usually performed here. - // Upon exit, all smem stores for TMA must have been issued - CUTLASS_DEVICE void - postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); - } - ); - } - - // After smem async fence, before TMA store commit. Aux stores usually performed here - // Upon exit, all TMA stores for this subtile must have been issued - // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores - // other gmem stores can be placed in the reduce or postreduce entry points - CUTLASS_DEVICE void - tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); - } - ); - } - - // End of subtile store iteration - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end_loop(epi_m, epi_n); - } - ); - } - - // Exit of subtile store loop. Gmem reductions usually performed here. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } - }; - // Consumer store callbacks factory // All operations must redefine this template < @@ -544,7 +545,7 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { }, [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { auto callbacks_tuple = cute::make_tuple(callbacks...); - return ConsumerStoreCallbacks{callbacks_tuple}; + return ConsumerStoreCallbacksImpl{callbacks_tuple}; } ); } @@ -553,8 +554,8 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { ///////////////////////////////////////////////////////////////////////////////////////////////// // Convenience aliases -using EmptyProducerLoadCallbacks = Sm90VisitorImpl<>::ProducerLoadCallbacks>; -using EmptyConsumerStoreCallbacks = Sm90VisitorImpl<>::ConsumerStoreCallbacks>; +using EmptyProducerLoadCallbacks = ProducerLoadCallbacksImpl>; +using EmptyConsumerStoreCallbacks = ConsumerStoreCallbacksImpl>; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -614,9 +615,9 @@ struct Sm90TreeVisitor : Sm90VisitorImpl { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Sm90VisitorImpl:: + auto callbacks_impl = Sm90VisitorImpl:: template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); } }; @@ -663,9 +664,9 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Sm90VisitorImpl:: + auto callbacks_impl = Sm90VisitorImpl:: template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -739,9 +740,9 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Sm90VisitorImpl:: + auto callbacks_impl = Sm90VisitorImpl:: template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); } }; diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 145488ae..33c5585f 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -210,6 +210,45 @@ struct Clamp> { } }; +// Lower Bound +template +struct LowerBound { + struct Arguments { + T lower_bound; + }; + + CUTLASS_HOST_DEVICE + T operator()(T const& value, T const& lower_bound) const { + constexpr bool PropagateNaN = true; + maximum mx; + + return mx(value, lower_bound); + } + + CUTLASS_HOST_DEVICE + T operator()(T const& value, Arguments const& args = Arguments()) const { + return this->operator()(value, args.lower_bound); + } +}; + +template +struct LowerBound> { + using Arguments = typename LowerBound::Arguments; + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, T const& lower_bound) const { + constexpr bool PropagateNaN = true; + maximum, PropagateNaN> mx; + + return mx(values, lower_bound); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, Arguments const& args = Arguments()) const { + return this->operator()(values, args.lower_bound); + } +}; + // Leaky Relu operator template struct LeakyReLU { @@ -567,6 +606,28 @@ struct GELU_taylor { } }; +template <> +struct GELU_taylor { + static const bool kIsHeavy = true; + using T = float; + CUTLASS_HOST_DEVICE + T operator()(T const &z) const { + // 0.5f * (x + x * tanh(x * (0.797885f + 0.0356774f * x * x))); + T k0 = T(0.7978845608028654); + T tmp = T(0.044715); + T k1 = T(k0*tmp); + multiply_add fma; + multiplies mul; + T v0 = mul(k1, z); + T v1 = fma(v0, z, k0); + T v2 = mul(z, v1); + T v3 = fast_tanh(v2); + T v4 = fma(z, v3, z); + T v5 = mul(cutlass::constants::half(), v4); + return v5; + } +}; + template struct GELU_taylor > { static const bool kIsHeavy = true; @@ -594,6 +655,30 @@ struct GELU_taylor > { } }; +template +struct GELU_taylor > { + static const bool kIsHeavy = true; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &value) const { + multiply_add> fma; + multiplies> mul; + fast_tanh_op> tanh; + // 0.5f * (x + x * tanh(x * (0.797885f + 0.0356774f * x * x))); + float k0 = float(0.7978845608028654); + float tmp = float(0.044715); + float k1 = float(k0*tmp); + + Array v0 = mul(k1, value); + Array v1 = fma(v0, value, k0); + Array v2 = mul(value, v1); + Array v3 = tanh(v2); + Array v4 = fma(value, v3, value); + Array v5 = mul(cutlass::constants::half(), v4); + return v5; + } +}; + template struct GELU_taylor > { static const bool kIsHeavy = true; diff --git a/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h b/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h index bdd75a69..245499b0 100644 --- a/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h +++ b/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h @@ -43,8 +43,6 @@ #pragma once -#if !(defined(__clang__) && defined(__CUDA__)) - #include "cutlass/wmma_array.h" #include "cutlass/layout/matrix.h" @@ -158,7 +156,3 @@ public: //////////////////////////////////////////////////////////////////////////////// -#else -#error (defined(__clang__) && defined(__CUDA__)) -#endif // !defined(__clang__) - diff --git a/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h index 8dbb1282..8129dce1 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h @@ -34,8 +34,6 @@ #pragma once -#if !(defined(__clang__) && defined(__CUDA__)) - #include "cutlass/cutlass.h" #include "cutlass/wmma_array.h" #include "cutlass/layout/matrix.h" @@ -223,5 +221,4 @@ public: ///////////////////////////////////////////////////////////////////////////////////////////////// -#endif // !defined(__clang__) diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index ac73bf0b..4a758b7f 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -399,6 +399,20 @@ struct FastDivmod { return div(dividend); } + /// Computes integer division remainder using precomputed values. + CUTLASS_HOST_DEVICE + int rem(int dividend) const { + int quotient, remainder; + fast_divmod(quotient, remainder, dividend); + return remainder; + } + + /// Alias for `rem` + CUTLASS_HOST_DEVICE + int remainder(int dividend) const { + return rem(dividend); + } + /// Computes integer division and modulus using precomputed values. This is computationally /// inexpensive. /// diff --git a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl index d2b87315..7b47c1f9 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl @@ -113,6 +113,122 @@ sm100_compute_stage_count_or_override_blockwise(StageCountAutoCarveout +auto sm100_make_simt_gmem_tiled_copy_SFA() { + + // we have at most a warp to perform the loads + + constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{}); + constexpr int ScaleMsPerTile = size<0>(CtaShape_MNK{}) / ScaleGranularityM; + constexpr int ScaleGranularityK = size<1,0>(LayoutSFA{}); + constexpr int ScaleKsPerTile = size<2>(CtaShape_MNK{}) / ScaleGranularityK; + + if constexpr (size<0,1>(LayoutSFA{}.stride()) == 1) { + constexpr int LeadingScalesPerTileSFA = ScaleMsPerTile; + if constexpr (LeadingScalesPerTileSFA >= 32) { + constexpr int Alignment = cute::min(static_cast(LeadingScalesPerTileSFA * sizeof(Element)) / 32, 16); + using ScaleCopyTypeA = cute::uint_byte_t; + using SmemScalingCopyAtomA = Copy_Atom, Element>; + constexpr int ElementsPerSFACopy = static_cast(sizeof(ScaleCopyTypeA) / sizeof(Element)); + return make_tiled_copy(SmemScalingCopyAtomA{}, Layout>{}, Layout>>{}); + } + else { + using SmemScalingCopyAtomA = Copy_Atom, Element>; + return make_tiled_copy(SmemScalingCopyAtomA{}, Layout>>{}, Layout>{}); + } + } + else { + // we expect scale Ks per tile to be small + constexpr int LeadingScalesPerTileSFA = ScaleKsPerTile; + using SmemScalingCopyAtomA = Copy_Atom, Element>; + return make_tiled_copy(SmemScalingCopyAtomA{}, Layout>>{}, Layout>{}); + } +} + +template +auto sm100_make_simt_gmem_tiled_copy_SFB() { + + // we have at most a warp to perform the loads + + constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{}); + constexpr int ScaleNsPerTile = size<1>(CtaShape_MNK{}) / ScaleGranularityN; + constexpr int ScaleGranularityK = size<1,0>(LayoutSFB{}); + constexpr int ScaleKsPerTile = size<2>(CtaShape_MNK{}) / ScaleGranularityK; + + if constexpr (size<0,1>(LayoutSFB{}.stride()) == 1) { + constexpr int LeadingScalesPerTileSFB = ScaleNsPerTile; + if constexpr (LeadingScalesPerTileSFB >= 32) { + constexpr int Alignment = cute::min(static_cast(LeadingScalesPerTileSFB * sizeof(Element)) / 32, 16); + using ScaleCopyTypeB = cute::uint_byte_t; + using SmemScalingCopyAtomB = Copy_Atom, Element>; + constexpr int ElementsPerSFBCopy = static_cast(sizeof(ScaleCopyTypeB) / sizeof(Element)); + return make_tiled_copy(SmemScalingCopyAtomB{}, Layout>{}, Layout>>{}); + } + else { + using SmemScalingCopyAtomB = Copy_Atom, Element>; + return make_tiled_copy(SmemScalingCopyAtomB{}, Layout>>{}, Layout>{}); + } + } + else { + // we expect scale Ks per tile to be small + constexpr int LeadingScalesPerTileSFB = ScaleKsPerTile; + using SmemScalingCopyAtomB = Copy_Atom, Element>; + return make_tiled_copy(SmemScalingCopyAtomB{}, Layout>>{}, Layout>{}); + } +} + +// For new MMA construction and partitioning that supports both dynamic and static cluster shape. +// Used in conjunction with make_tma_atom_(A|B)_sm100 +// TileShape_MNK is always static and has shape (MmaAtomShapeM, MmaAtomShapeN, TileK) +// ClusterShape_MNK can be dynamic or static. +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class BuilderScheduleTag, + UMMA::ScaleIn ANeg = UMMA::ScaleIn::One, + UMMA::ScaleIn BNeg = UMMA::ScaleIn::One +> +constexpr auto +sm100_make_trivial_tiled_mma_blockwise() { + // MMA_2SM requested + if constexpr (cute::is_base_of_v ) { + return sm100_make_2sm_trivial_tiled_mma(); + } + // MMA_1SM requested + else if constexpr (cute::is_base_of_v ) { + return sm100_make_1sm_trivial_tiled_mma(); + } + // Auto scheduling requested + else if constexpr (cute::is_same_v) { + // Static cluster + if constexpr (cute::is_static_v) { + // For MMA_2SM we need a cluster shape that is multiple of 2x1 + // and only M=128 and M=256 are supported, otherwise, fall back to MMA_1SM + if constexpr (cute::size<0>(ClusterShape_MNK{}) % 2 == 0 && + cute::size<0>(TileShape_MNK{}) % 128 == 0) { + return sm100_make_2sm_trivial_tiled_mma(); + } + else { + return sm100_make_1sm_trivial_tiled_mma(); + } + // Dynamic cluster shape means we cannot assume we can use 2SM MMA + } + else { + return sm100_make_1sm_trivial_tiled_mma(); + } + } +} + } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -161,9 +277,11 @@ struct CollectiveBuilder< using GmemLayoutBTag = cute::remove_cvref_t(GmemLayoutBTagPair{}))>; using GmemLayoutSFBTag = cute::remove_cvref_t(GmemLayoutBTagPair{}))>; - static_assert(cute::depth(GmemLayoutSFATag{}) == 2 and cute::depth(GmemLayoutSFBTag{}) == 2, + static_assert(cute::depth(cute::remove_pointer_t{}) == 2 and + cute::depth(cute::remove_pointer_t{}) == 2, "Expect SFA and SFB layout to be depth of two with shape ((SFVecMN, restMN),(SFVecK, restK), L)"); - static_assert(size<1,0>(GmemLayoutSFATag{}) == size<1, 0>(GmemLayoutSFBTag{}), + static_assert(size<1,0>(cute::remove_pointer_t{}) == + size<1,0>(cute::remove_pointer_t{}), "SFA and SFB must have equivalent SF vector sizes along K"); static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); @@ -183,7 +301,7 @@ struct CollectiveBuilder< TileShape_MNK, ClusterShape_MNK, GmemLayoutATag, GmemLayoutBTag, false /*is_sparse*/, is_2sm>(), "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); - using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma_blockwise< ElementAMma, ElementBMma, ElementAccumulator, decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); @@ -238,12 +356,14 @@ struct CollectiveBuilder< // SchedulerPipelineStageCount could be set to zero for Grouped GEMM, but we shouldn't define CLC Pipeline's barrier arrays of size zero. static constexpr uint32_t SchedulerPipelineStageCount = cute::is_same_v ? (AccumulatorPipelineStageCount + 1) : 1; + static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< ClusterShape_MNK, AccumulatorPipelineStageCount, SchedulerPipelineStageCount, detail::CLCResponseSize, - false + IsArrayOfPointersGemm >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; @@ -253,14 +373,23 @@ struct CollectiveBuilder< using TransformLoadPipelineStorage = typename cutlass::PipelineAsync<1>::SharedStorage; using TransformPipelineStorage = typename cutlass::PipelineUmmaAsync<1>::SharedStorage; - static constexpr int ScaleGranularityM = size<0,0>(GmemLayoutSFATag{}); - static constexpr int ScaleGranularityN = size<0,0>(GmemLayoutSFBTag{}); - static constexpr int ScaleGranularityK = size<1,0>(GmemLayoutSFBTag{}); + static constexpr int ScaleGranularityM = size<0,0>(cute::remove_pointer_t{}); + static constexpr int ScaleGranularityN = size<0,0>(cute::remove_pointer_t{}); + static constexpr int ScaleGranularityK = size<1,0>(cute::remove_pointer_t{}); static_assert(size<0>(CtaTileShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape"); static_assert(size<1>(CtaTileShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); static_assert(size<2>(CtaTileShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); + using GmemTiledCopySFA = decltype(detail::sm100_make_simt_gmem_tiled_copy_SFA< + ElementAccumulator, + cute::remove_pointer_t, + CtaTileShape_MNK>()); + using GmemTiledCopySFB = decltype(detail::sm100_make_simt_gmem_tiled_copy_SFB< + ElementAccumulator, + cute::remove_pointer_t, + CtaTileShape_MNK>()); + using BlockTileScale_M = Int(TileShape_MNK{}) / ScaleGranularityM>; using BlockTileScale_N = Int(TileShape_MNK{}) / ScaleGranularityN>; using BlockTileScale_K = Int(TileShape_MNK{}) / ScaleGranularityK>; @@ -273,11 +402,18 @@ struct CollectiveBuilder< TransformLoadPipelineStorage, TransformPipelineStorage>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, and scales."); - using DispatchPolicy = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< + using DispatchPolicy = cute::conditional_t< + IsArrayOfPointersGemm, + cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>; + ClusterShape_MNK>, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>>; using CollectiveOp = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, @@ -287,11 +423,11 @@ struct CollectiveBuilder< ElementB, cute::tuple, cutlass::gemm::TagToStrideB_t>, TiledMma, - GmemTiledCopyA, + cute::tuple, SmemLayoutAtomA, void, cute::identity, - GmemTiledCopyB, + cute::tuple, SmemLayoutAtomB, void, cute::identity diff --git a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl index 062631e7..24f8c202 100755 --- a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl @@ -104,10 +104,10 @@ struct CollectiveBuilder< UmmaMajorB, BuilderScheduleTag>(); static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8; - using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{})); using PermTileN = decltype(detail::sm120_tile_n_permute_selector()); - using PermTileK = cute::conditional_t; + using PermTileK = cute::conditional_t<(UseMxf8f6f4 + ), _32, _64>; static constexpr bool IsCooperative = !cute::is_base_of_v; // Data type used by MMA instruction @@ -124,7 +124,13 @@ struct CollectiveBuilder< Layout>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma( - cute::rr_blockscaled_op_selector_sm120(), + cute::rr_blockscaled_op_selector_sm120(), AtomLayoutMNK{}, Tile{} )); @@ -150,8 +156,14 @@ struct CollectiveBuilder< using SmemLayoutAtomA = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); using SmemLayoutAtomB = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); - using SmemCopyAtomA = Copy_Atom()), SmemAllocTypeA>; - using SmemCopyAtomB = Copy_Atom()), SmemAllocTypeB>; + using SmemCopyAtomA = Copy_Atom()), SmemAllocTypeA>; + using SmemCopyAtomB = Copy_Atom()), SmemAllocTypeB>; using SmemCopyAtomSF = Copy_Atom, SmemAllocTypeSF>; // auto-vectorized LDS using SmemCopyAtomSFA = SmemCopyAtomSF; diff --git a/include/cutlass/gemm/collective/builders/sm120_common.inl b/include/cutlass/gemm/collective/builders/sm120_common.inl index 7915eb97..45e201b3 100644 --- a/include/cutlass/gemm/collective/builders/sm120_common.inl +++ b/include/cutlass/gemm/collective/builders/sm120_common.inl @@ -45,7 +45,11 @@ namespace cutlass::gemm::collective::detail { constexpr int sm120_smem_capacity_bytes = cutlass::arch::sm120_smem_capacity_bytes; // Helper for selecting the shared memory copy atom to use for operand A -template +template < + class ElementA, + class ElementB, + bool UseF8f6f4 +> CUTLASS_HOST_DEVICE constexpr auto sm120_rr_smem_copy_selector_A() { @@ -66,7 +70,11 @@ sm120_rr_smem_copy_selector_A() { } // Helper for selecting the shared memory copy atom to use for operand B -template +template < + class ElementA, + class ElementB, + bool UseF8f6f4 +> CUTLASS_HOST_DEVICE constexpr auto sm120_rr_smem_copy_selector_B() { diff --git a/include/cutlass/gemm/collective/builders/sm1xx_common.inl b/include/cutlass/gemm/collective/builders/sm1xx_common.inl index 76c71f80..cb9e74c1 100644 --- a/include/cutlass/gemm/collective/builders/sm1xx_common.inl +++ b/include/cutlass/gemm/collective/builders/sm1xx_common.inl @@ -467,6 +467,8 @@ check_input_datatypes() { // SfVectorSize = 64 for blockscaled sparse gemm static_assert( ((SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 64 && cute::is_base_of_v) @@ -645,6 +647,8 @@ select_instr() { static_assert( (SfVectorSize == 32 && cute::is_same_v) || (SfVectorSize == 32 && cute::is_base_of_v) + || (SfVectorSize == 32 && cute::is_base_of_v) + || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 64 && cute::is_base_of_v || (SfVectorSize == 32 && cute::is_base_of_v) @@ -666,6 +670,8 @@ select_instr() { else { static_assert( ((SfVectorSize == 32 && cute::is_same_v) + || (SfVectorSize == 32 && cute::is_base_of_v) + || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 32 && cute::is_base_of_v) || (SfVectorSize == 64 && cute::is_base_of_v) diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 4b4a105e..57b34afd 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -61,6 +61,7 @@ #include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp" #include "cutlass/gemm/collective/sm120_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_sparse_mma_tma.hpp" diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp index 69b33179..1ec07261 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -28,8 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - - #pragma once #include "cutlass/cutlass.h" @@ -989,12 +987,59 @@ struct CollectiveMma< uint32_t skip_wait = k_tile_count <= 0; auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - bool is_first_iter = true; // // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + if constexpr (IsOverlappingAccum) { + // first iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { // WAIT on mainloop_pipe_consumer_state until its data are available @@ -1018,12 +1063,6 @@ struct CollectiveMma< copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); } - // Wait for tmem accumulator buffer to become empty with a flipped phase - if (is_first_iter) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - is_first_iter = false; - } - // Unroll the K mode manually so we can set scale C to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { @@ -1036,6 +1075,7 @@ struct CollectiveMma< accumulators); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); } diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp index 8ef3d8f5..f61c9da1 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp @@ -1197,12 +1197,61 @@ struct CollectiveMma< uint32_t skip_wait = k_tile_count <= 0; auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - bool is_first_iter = true; // // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + if constexpr (IsOverlappingAccum) { + // first iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_E, thr_tCsE_s2t(_,_,_,_,read_stage), thr_tCtE_s2t); + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtE(_,_,k_block), + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { // WAIT on mainloop_pipe_consumer_state until its data are available @@ -1227,12 +1276,6 @@ struct CollectiveMma< copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); } - // Wait for tmem accumulator buffer to become empty with a flipped phase - if (is_first_iter) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - is_first_iter = false; - } - // Unroll the K mode manually so we can set scale C to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp new file mode 100644 index 00000000..1d6e1158 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp @@ -0,0 +1,1330 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using InternalLayoutSFA = cute::remove_pointer_t; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + static constexpr int ScaleGranularityM = size<0,0>(InternalLayoutSFA{}); + + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0 and ScaleGranularityM <= size<0>(TileShape{}), "Scale Granularity M must divide Tile Shape"); + + static constexpr int ScaleGranularityN = size<0,0>(InternalLayoutSFB{}); + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0 and ScaleGranularityN <= size<1>(TileShape{}), "Scale Granularity N must divide Tile Shape"); + + static_assert(size<1, 0>(InternalLayoutSFA{}) == size<1, 0>(InternalLayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + + static constexpr int ScaleGranularityK = size<1, 0>(InternalLayoutSFA{}); + static constexpr int ScaleKsPerTile = size<2>(TileShape{}) / ScaleGranularityK; + static_assert(size<2>(TileShape{}) % ScaleGranularityK == 0 and ScaleGranularityK <= size<2>(TileShape{}), "Scale Granularity K must divide Tile Shape"); + static_assert(ScaleGranularityK % size<2>(typename TiledMma::AtomShape_MNK{}) == 0, "Scale Granularity K must be divisible by MMA_K"); + + static constexpr int K_BLOCK_MMAS_PER_SCALE_K = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); + + static_assert(size<0>(CtaShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<1>(CtaShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<2>(CtaShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); + + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig(InternalLayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, + size<0,1>(InternalLayoutSFB{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K>; + + using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(CtaShape_MNK{})); + using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(CtaShape_MNK{})); + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopySFA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopyB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; + using GmemTiledCopySFB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineAsync; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync< + AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + // Two arrivals per thread in the warp (1 arrival and 1 arrival through cp.async.mbarrier) + static constexpr int NumMainloopSFProducerThreadEvents = 64; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = uint_bit_t>; + using BitTypeElementB = uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + using SmemLayoutScaleA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutScaleB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + struct PipelineStorage { + alignas(16) PipelineABStorage pipeline_ab; + alignas(16) PipelineSFStorage pipeline_sf; + alignas(16) AccumulatorPipelineStorage pipeline_accum; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + + ElementAccumulator const** ptr_SFA; + LayoutSFA layout_SFA; + ElementAccumulator const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + // Tensor pointers will be fixed before the first access + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + args.ptr_SFA, + args.layout_SFA, + args.ptr_SFB, + args.layout_SFB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + bool implementable_sf = cutlass::detail::check_alignment(InternalLayoutSFA{}); + implementable_sf = implementable_sf && cutlass::detail::check_alignment(InternalLayoutSFB{}); + + if (!implementable_sf) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for Scale Factors.\n"); + } + implementable = implementable && implementable_sf; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE auto + slice_accumulator(cute::Tensor const& accumulators, int stage) { + return accumulators(_,_,_,stage); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + int current_group) const { + return load_sf_update(problem_shape_MNKL, params, shared_tensors, current_group); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + template + CUTLASS_DEVICE auto + load_sf_update( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + int current_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + auto layout_SFA = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsGroupedGemmKernel) { + return params.layout_SFA[current_group]; + } + else { + return params.layout_SFA; + } + }(); + + auto layout_SFB = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsGroupedGemmKernel) { + return params.layout_SFB[current_group]; + } + else { + return params.layout_SFB; + } + }(); + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(params.ptr_SFA[current_group]), layout_SFA); // (m,k,l) + + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(params.ptr_SFB[current_group]), layout_SFB); // (n,k,l) + + Tensor SFA_mkl_ident = make_identity_tensor(shape(layout_SFA)); + + Tensor SFB_nkl_ident = make_identity_tensor(shape(layout_SFB)); + + // Tile the tensors and defer the slice + Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + static_assert(rank(decltype(gSFA_mkl){}) == 5); + static_assert(rank(decltype(gSFB_nkl){}) == 5); + + // 1 thread copies entire set of scalar + GmemTiledCopySFA scale_copy_a{}; + GmemTiledCopySFB scale_copy_b{}; + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) + + Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) + Tensor tSFAIdentSFA_mkl = thr_scale_copy_a.partition_S(identSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) + + Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tSFBgSFB_nkl = thr_scale_copy_b.partition_S(gSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) + Tensor tSFBIdentSFB_nkl = thr_scale_copy_b.partition_S(identSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) + Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); + + static_assert(rank(decltype(tSFAgSFA_mkl){}) == 6); + static_assert(rank(decltype(tSFBgSFB_nkl){}) == 6); + + return cute::make_tuple(gA_mkl, + tSFAgSFA_mkl, tSFBgSFB_nkl, + tSFAsSFA, tSFBsSFB, + tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, + layout_SFA, layout_SFB); + } + + /// Setup data needed for transform + CUTLASS_DEVICE auto + accum_init( + TensorStorage& shared_tensors) const { + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) + + return cute::make_tuple(sSFA, sSFB); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] cute::Tensor const& accumulators, + TensorStorage& shared_tensors, + [[maybe_unused]] uint32_t const tmem_nonaccum_offset) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA_ = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB_ = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(rank(tCrA_) == _4{}); + + auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), + get<1>(shape(tCrA_.layout())), + Int{}, + _1{}); + + auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), + get<1>(shape(tCrB_.layout())), + Int{}, + _1{}); + + Tensor tCrA = flat_divide(tCrA_, + mma_tile_shape_A)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_M,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + + Tensor tCrB = flat_divide(tCrB_, + mma_tile_shape_B)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_N,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline mainloop_ab_pipeline, + MainloopABPipelineState mainloop_ab_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_ab_pipeline.producer_try_acquire(mainloop_ab_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_ab_pipeline.producer_acquire(mainloop_ab_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_ab_pipeline.producer_get_barrier(mainloop_ab_pipe_producer_state); + + int write_stage = mainloop_ab_pipe_producer_state.index(); + ++mainloop_ab_pipe_producer_state; + barrier_token = mainloop_ab_pipeline.producer_try_acquire(mainloop_ab_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_ab_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_ab_tail(MainloopABPipeline mainloop_ab_pipeline, MainloopABPipelineState mainloop_ab_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_ab_pipeline.producer_tail(mainloop_ab_pipe_producer_state); + } + + /// Perform a collective-scoped transform + /// Producer Perspective + template < + class UnusedGTensorA, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class IdentPartitionedSFA, class IdentPartitionedSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple const& mainloop_sf_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused, tSFAgSFA_mkl, tSFBgSFB_nkl, + tSFAsSFA, tSFBsSFB, + tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, + layout_SFA, layout_SFB] = mainloop_sf_inputs; + + // slice out the work coord from partitioned tensors + GmemTiledCopySFA scale_copy_a{}; + GmemTiledCopySFB scale_copy_b{}; + + Tensor tSFAgSFA = tSFAgSFA_mkl(_, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor tSFBgSFB = tSFBgSFB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor thr_tile_SFA_k = tSFAIdentSFA_mkl(_0{}, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor thr_tile_pSFA = make_tensor(shape(filter_zeros(thr_tile_SFA_k(_,_,_0{}), tSFAgSFA(_0{},_,_,_0{}).stride()))); + Tensor thr_tile_SFB_k = tSFBIdentSFB_nkl(_0{}, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor thr_tile_pSFB = make_tensor(shape(filter_zeros(thr_tile_SFB_k(_,_,_0{}), tSFBgSFB(_0{},_,_,_0{}).stride()))); + + // Issue the loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK pipe_producer_state for _writing_ + mainloop_sf_pipeline.producer_acquire(mainloop_sf_pipe_producer_state); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFA); ++i) { + Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); + thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFB); ++i) { + Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); + thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); + } + + copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,mainloop_sf_pipe_producer_state.index()))); + copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,mainloop_sf_pipe_producer_state.index()))); + mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); + + __syncwarp(); + + ++mainloop_sf_pipe_producer_state; + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_sf_tail( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_sf_pipeline.producer_tail(mainloop_sf_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 4, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N, P)"); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + CUTLASS_PRAGMA_UNROLL + for (int scale_k_iter = 0; scale_k_iter < size<3>(tCrA); ++scale_k_iter) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + auto acc = slice_accumulator(accumulators, accumulator_pipe_producer_state.index()); + static_assert(is_tmem>::value, "Accumulator must be tmem resident."); + static_assert(rank(remove_cvref_t{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + // for each set of scale_k_iter we zero the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,scale_k_iter,read_stage), + tCrB(_,_,k_block,scale_k_iter,read_stage), + acc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + + } + + return make_tuple(mainloop_pipe_consumer_state, accumulator_pipe_producer_state); + + } + + /// Transform + template < + class FrgEngine, + class FrgLayout, + class TensorsSFA, + class TensorsSFB, + class CtaTileCoord, + class CopyOpT2R, + class EpilogueTile + > + CUTLASS_DEVICE auto + accum( + cute::tuple pipelines, + cute::tuple consumer_states, + cute::Tensor const& accumulators, + cute::tuple const& transform_inputs, + CtaTileCoord cta_tile_coord, + CopyOpT2R, + EpilogueTile, + int k_tile_count) { + + static_assert(size<0>(EpilogueTile{}) <= size<0>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal to CTA Tile"); + static_assert(size<1>(EpilogueTile{}) <= size<1>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal to CTA Tile"); + + + // + // PIPELINED Transform + // + + Tensor acc = slice_accumulator(accumulators, _0{}); + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + auto [sSFA_, sSFB_] = transform_inputs; + + // Append N with a stride of 0 to SFA + Tensor sSFA = make_tensor(sSFA_.data(), make_layout( + make_shape(get<0>(sSFA_.shape()), get<1>(CtaShape_MNK{}), get<1>(sSFA_.shape()), get<2>(sSFA_.shape())), + make_stride(get<0>(sSFA_.stride()), _0{}, get<1>(sSFA_.stride()), get<2>(sSFA_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFA) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFA) == size<1>(tAcc)); + + Tensor sSFA_epi = flat_divide(sSFA, EpilogueTile{}); + + // Append M with a stride of 0 to SFB + Tensor sSFB = make_tensor(sSFB_.data(), make_layout( + make_shape(get<0>(CtaShape_MNK{}), get<0>(sSFB_.shape()), get<1>(sSFB_.shape()), get<2>(sSFB_.shape())), + make_stride(_0{}, get<0>(sSFB_.stride()), get<1>(sSFB_.stride()), get<2>(sSFB_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFB) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFB) == size<1>(tAcc)); + + Tensor sSFB_epi = flat_divide(sSFB, EpilogueTile{}); + + TiledCopy tiled_t2r_epi = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + + int thread_idx = threadIdx.x % size(tiled_t2r_epi); + + ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); + + Tensor acc_ident_epi = make_identity_tensor(shape(tAcc_epi)); + + Tensor tTR_rAcc_epi = thread_t2r_epi.partition_D(acc_ident_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + Tensor tTR_sSFA_epi = thread_t2r_epi.partition_D(sSFA_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + Tensor tTR_sSFB_epi = thread_t2r_epi.partition_D(sSFB_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + static_assert(rank(decltype(tTR_sSFA_epi){}) == 7); + + Tensor tTR_FullAcc = make_tensor(shape(tTR_rAcc_epi)); + Tensor tTR_PartAcc = make_tensor(shape(tTR_rAcc_epi(_,_,_,_0{},_0{}))); + + Tensor tTR_rSFA_compact = make_fragment_like(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,_0{}))); + Tensor tTR_rSFB_compact = make_fragment_like(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,_0{}))); + + Layout tTR_rSFA_layout = make_layout(tTR_sSFA_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFA_compact.stride()); + Layout tTR_rSFB_layout = make_layout(tTR_sSFB_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFB_compact.stride()); + + // Zero our accumulator + clear(tTR_FullAcc); + + auto [accumulator_pipeline, mainloop_sf_pipeline] = pipelines; + auto [accumulator_pipe_state, mainloop_sf_pipe_state] = consumer_states; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_sf_pipeline.consumer_wait(mainloop_sf_pipe_state); + int read_idx = mainloop_sf_pipe_state.index(); + + copy(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,read_idx)), tTR_rSFA_compact); + copy(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,read_idx)), tTR_rSFB_compact); + + CUTE_STATIC_ASSERT_V(cosize(tTR_rSFA_layout) == size(tTR_rSFA_compact)); + CUTE_STATIC_ASSERT_V(cosize(tTR_rSFB_layout) == size(tTR_rSFB_compact)); + + Tensor tTR_rSFA = make_tensor(tTR_rSFA_compact.data(), tTR_rSFA_layout); + Tensor tTR_rSFB = make_tensor(tTR_rSFB_compact.data(), tTR_rSFB_layout); + + mainloop_sf_pipeline.consumer_release(mainloop_sf_pipe_state); + ++mainloop_sf_pipe_state; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < ScaleKsPerTile; ++k_block) { + + accumulator_pipeline.consumer_wait(accumulator_pipe_state); + + Tensor acc = slice_accumulator(accumulators, accumulator_pipe_state.index()); + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + Tensor tTR_tAcc = thread_t2r_epi.partition_S(tAcc_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(tAcc_epi); ++epi_m) { + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(tAcc_epi); ++epi_n) { + + auto scale_a = tTR_rSFA(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + auto scale_b = tTR_rSFB(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + + Tensor full_acc = tTR_FullAcc(_,_,_,epi_m,epi_n); + // Compute tmem load predication if necessary + copy(tiled_t2r_epi, tTR_tAcc(_,_,_,epi_m,epi_n), tTR_PartAcc); + cutlass::arch::fence_view_async_tmem_load(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(full_acc); ++i) { + ElementAccumulator scale = scale_a(i) * scale_b(i); + full_acc(i) += scale * tTR_PartAcc(i); + } + } + } + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_state); + // release acc + ++accumulator_pipe_state; + } + + --k_tile_count; + } + + return cute::make_tuple(tTR_FullAcc, tiled_t2r_epi, cute::make_tuple(accumulator_pipe_state, mainloop_sf_pipe_state)); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +private: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp index 6be30cbc..b0d2d0f6 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp @@ -667,12 +667,14 @@ struct CollectiveMma< uint32_t skip_wait = k_tile_count <= 0; auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - bool is_first_iter = true; // // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { // WAIT on mainloop_pipe_consumer_state until its data are available @@ -690,11 +692,6 @@ struct CollectiveMma< skip_wait = k_tile_count <= 0; // Peek at next iteration barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - // Wait for tmem accumulator buffer to become empty with a flipped phase - if (is_first_iter) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - is_first_iter = false; - } // Unroll the K mode manually so we can set scale C to 1 CUTLASS_PRAGMA_UNROLL diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp index cb621a5f..a8ebd512 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp @@ -70,11 +70,11 @@ template < class ElementB_, class StridePairB_, class TiledMma_, - class GmemTiledCopyA_, + class GmemTiledCopyPairA_, class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_, - class GmemTiledCopyB_, + class GmemTiledCopyPairB_, class SmemLayoutAtomB_, class SmemCopyAtomB_, class TransformB_> @@ -90,11 +90,11 @@ struct CollectiveMma< ElementB_, StridePairB_, TiledMma_, - GmemTiledCopyA_, + GmemTiledCopyPairA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, + GmemTiledCopyPairB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> @@ -142,9 +142,6 @@ struct CollectiveMma< static constexpr int K_BLOCK_MMAS_PER_SCALE_K = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); - static constexpr int TILE_M = size<0>(TileShape{}); - static constexpr int TILE_N = size<1>(TileShape{}); - using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig(AtomThrShapeMNK{}) == 1, "2SM MMA is not yet supported"); - static_assert(size<0>(CtaShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape"); static_assert(size<1>(CtaShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); static_assert(size<2>(CtaShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); @@ -180,8 +175,10 @@ struct CollectiveMma< static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopySFA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopyB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; + using GmemTiledCopySFB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; using SmemLayoutAtomA = SmemLayoutAtomA_; using SmemLayoutAtomB = SmemLayoutAtomB_; using SmemCopyAtomA = SmemCopyAtomA_; @@ -190,22 +187,22 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< - DispatchPolicy::Stages, - ClusterShape, - AtomThrShapeMNK>; - using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; - using Load2TransformPipeline = cutlass::PipelineAsync; - using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + using MainloopSFPipeline = cutlass::PipelineAsync; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; - using Mma2TransformPipeline = cutlass::PipelineUmmaAsync< + using AccumulatorPipeline = cutlass::PipelineUmmaAsync< AccumulatorPipelineStageCount, AtomThrShapeMNK>; - using Mma2TransformPipelineState = typename Mma2TransformPipeline::PipelineState; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; - // Two arrivals per CTA (1 arrival and 1 arrival through cp.async.mbarrier) - static constexpr int NumLoad2TransformProducerThreadEvents = 2; + // Two arrivals per thread in the warp (1 arrival and 1 arrival through cp.async.mbarrier) + static constexpr int NumMainloopSFProducerThreadEvents = 64; static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, @@ -277,43 +274,28 @@ struct CollectiveMma< append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) )); - // Scaling gmem-to-smem copy atom - static constexpr int LeadingScalesPerTileSFA = size<0,1>(LayoutSFA{}.stride()) == 1 ? ScaleMsPerTile : ScaleKsPerTile; - using ScaleCopyTypeA = cute::uint_byte_t(sizeof(ElementAccumulator)) * LeadingScalesPerTileSFA, 16)>; - using SmemScalingCopyAtomA = Copy_Atom, ElementAccumulator>; - static constexpr int ElementsPerSFACopy = static_cast(sizeof(ScaleCopyTypeA) / sizeof(ElementAccumulator)); - - static constexpr int LeadingScalesPerTileSFB = size<0,1>(LayoutSFB{}.stride()) == 1 ? ScaleNsPerTile : ScaleKsPerTile; - using ScaleCopyTypeB = cute::uint_byte_t(sizeof(ElementAccumulator)) * LeadingScalesPerTileSFB, 16)>; - using SmemScalingCopyAtomB = Copy_Atom, ElementAccumulator>; - static constexpr int ElementsPerSFBCopy = static_cast(sizeof(ScaleCopyTypeB) / sizeof(ElementAccumulator)); - - using TiledCopyScaleA = decltype(make_tiled_copy(SmemScalingCopyAtomA{}, Layout>{}, Layout>>{})); - using TiledCopyScaleB = decltype(make_tiled_copy(SmemScalingCopyAtomB{}, Layout>{}, Layout>>{})); - struct SharedStorage { struct TensorStorage : cute::aligned_struct<128, _0> { cute::ArrayEngine> smem_A; cute::ArrayEngine> smem_B; - cute::ArrayEngine> smem_scale_A; - cute::ArrayEngine> smem_scale_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; } tensors; - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; - using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; - Load2TransformPipelineStorage transform2load_pipeline; - - using Mma2TransformPipelineStorage = typename Mma2TransformPipeline::SharedStorage; - Mma2TransformPipelineStorage mma2transform_pipeline; + struct PipelineStorage { + alignas(16) PipelineABStorage pipeline_ab; + alignas(16) PipelineSFStorage pipeline_sf; + alignas(16) AccumulatorPipelineStorage pipeline_accum; + }; }; // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. using TensorStorage = typename SharedStorage::TensorStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; - using Mma2TransformPipelineStorage = typename SharedStorage::Mma2TransformPipelineStorage; - using Load2TransformPipelineStorage = typename SharedStorage::Load2TransformPipelineStorage; // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly static constexpr uint32_t TmaTransactionBytes = @@ -328,12 +310,9 @@ struct CollectiveMma< template< class KTileCount, class GTensorPartitionedA, class GTensorPartitionedB, - class STensorA, class STensorB, - class GTensorPartitionedScaleA, class GTensorPartitionedScaleB, - class IdentTensorPartitionedScaleA, class IdentTensorPartitionedScaleB, - class STensorScaleA, class STensorScaleB + class STensorA, class STensorB > - struct LoadParams { + struct LoadABParams { // for scheduler KTileCount k_tiles; // for input tensor values @@ -342,6 +321,32 @@ struct CollectiveMma< STensorA tAsA; STensorB tBsB; + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + + CUTLASS_DEVICE + LoadABParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, + STensorA tAsA_, STensorB tBsB_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) + , tAsA(tAsA_), tBsB(tBsB_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {} + }; + + template< + class KTileCount, + class GTensorPartitionedScaleA, class GTensorPartitionedScaleB, + class IdentTensorPartitionedScaleA, class IdentTensorPartitionedScaleB, + class STensorScaleA, class STensorScaleB + > + struct LoadSFParams { + // for scheduler + KTileCount k_tiles; + GTensorPartitionedScaleA tSFAgSFA_mkl; GTensorPartitionedScaleB tSFBgSFB_nkl; IdentTensorPartitionedScaleA tSFAIdentSFA_mkl; @@ -349,30 +354,20 @@ struct CollectiveMma< STensorScaleA tSFAsSFA; STensorScaleB tSFBsSFB; - // the TMA multicast masks - uint16_t mcast_mask_a; - uint16_t mcast_mask_b; - LayoutSFA layout_SFA; LayoutSFB layout_SFB; CUTLASS_DEVICE - LoadParams ( + LoadSFParams ( KTileCount k_tiles_, - GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, - STensorA tAsA_, STensorB tBsB_, GTensorPartitionedScaleA tSFAgSFA_mkl_, GTensorPartitionedScaleB tSFBgSFB_nkl_, IdentTensorPartitionedScaleA tSFAIdentSFA_mkl_, IdentTensorPartitionedScaleB tSFBIdentSFB_nkl_, STensorScaleA tSFAsSFA_, STensorScaleB tSFBsSFB_, - uint16_t mcast_mask_a_, uint16_t mcast_mask_b_, LayoutSFA layout_SFA_, LayoutSFB layout_SFB_) : k_tiles(k_tiles_) - , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) - , tAsA(tAsA_), tBsB(tBsB_) , tSFAgSFA_mkl(tSFAgSFA_mkl_), tSFBgSFB_nkl(tSFBgSFB_nkl_) , tSFAIdentSFA_mkl(tSFAIdentSFA_mkl_), tSFBIdentSFB_nkl(tSFBIdentSFB_nkl_) , tSFAsSFA(tSFAsSFA_), tSFBsSFB(tSFBsSFB_) - , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) , layout_SFA(layout_SFA_), layout_SFB(layout_SFB_) {} }; @@ -393,14 +388,14 @@ struct CollectiveMma< template< class STensorScaleA, class STensorScaleB > - struct TransformParams { + struct AccumTransformParams { // for scheduler STensorScaleA sSFA; STensorScaleB sSFB; CUTLASS_DEVICE - TransformParams ( + AccumTransformParams ( STensorScaleA sSFA_, STensorScaleB sSFB_) : sSFA(sSFA_), sSFB(sSFB_) {} }; @@ -412,9 +407,9 @@ struct CollectiveMma< StrideA dA{}; ArrayElementB const* ptr_B{nullptr}; StrideB dB{}; - ElementAccumulator const* ptr_scale_A{nullptr}; + ElementAccumulator const* ptr_SFA{nullptr}; LayoutSFA layout_SFA{}; - ElementAccumulator const* ptr_scale_B{nullptr}; + ElementAccumulator const* ptr_SFB{nullptr}; LayoutSFB layout_SFB{}; RuntimeDataTypeA runtime_data_type_a{}; RuntimeDataTypeB runtime_data_type_b{}; @@ -451,9 +446,9 @@ struct CollectiveMma< RuntimeDataTypeA runtime_data_type_a; RuntimeDataTypeB runtime_data_type_b; - ElementAccumulator const* ptr_scale_A; + ElementAccumulator const* ptr_SFA; LayoutSFA layout_SFA; - ElementAccumulator const* ptr_scale_B; + ElementAccumulator const* ptr_SFB; LayoutSFB layout_SFB; }; @@ -539,9 +534,9 @@ struct CollectiveMma< hw_info.cluster_shape_fallback, args.runtime_data_type_a, args.runtime_data_type_b, - args.ptr_scale_A, + args.ptr_SFA, args.layout_SFA, - args.ptr_scale_B, + args.ptr_SFB, args.layout_SFB }; } @@ -568,8 +563,8 @@ struct CollectiveMma< CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } - bool implementable_sf = cutlass::detail::check_alignment(args.layout_SFA); - implementable_sf = implementable_sf && cutlass::detail::check_alignment(args.layout_SFB); + bool implementable_sf = cutlass::detail::check_alignment(args.layout_SFA); + implementable_sf = implementable_sf && cutlass::detail::check_alignment(args.layout_SFB); if (!implementable_sf) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for Scale Factors.\n"); @@ -628,20 +623,12 @@ struct CollectiveMma< /// gB_nkl - The tiled tma tensor for input B /// tAsA - partitioned smem tensor for A /// tBsB - partitioned smem tensor for B - /// tSFAgSFA_mkl - partitioned gmem tensor for SFA - /// tSFBgSFB_nkl - partitioned gmem tensor for SFB - /// tSFAIdentSFA_mkl - partitioned identity tensor for SFA in gmem - /// tSFBIdentSFB_nkl - partitioned identity tensor for SFB in gmem - /// tSFAsSFA - partitioned smem tensor for SFA - /// tSFBsSFB - partitioned smem tensor for SFB /// mcast_mask_a - tma multicast mask for A /// mcast_mask_b - tma multicast mask for B - /// layout_SFA - layout of SFA in gmem - /// layout_SFB - layout of SFB in gmem template CUTLASS_DEVICE auto - load_init( + load_ab_init( ProblemShape_MNKL const& problem_shape_MNKL, MainloopParams const& mainloop_params, TensorStorage& shared_tensors) const { @@ -686,10 +673,38 @@ struct CollectiveMma< uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - // Scales + LoadABParams load_params { + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + }; + return load_params; + } - Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), mainloop_params.layout_SFA); // (m,k,l) - Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), mainloop_params.layout_SFB); // (n,k,l) + /// Set up the data needed by this collective for load. + /// Return load params containing + /// tSFAgSFA_mkl - partitioned gmem tensor for SFA + /// tSFBgSFB_nkl - partitioned gmem tensor for SFB + /// tSFAIdentSFA_mkl - partitioned identity tensor for SFA in gmem + /// tSFBIdentSFB_nkl - partitioned identity tensor for SFB in gmem + /// tSFAsSFA - partitioned smem tensor for SFA + /// tSFBsSFB - partitioned smem tensor for SFB + /// layout_SFA - layout of SFA in gmem + /// layout_SFB - layout of SFB in gmem + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), mainloop_params.layout_SFA); // (m,k,l) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), mainloop_params.layout_SFB); // (n,k,l) Tensor SFA_mkl_ident = make_identity_tensor(shape(mainloop_params.layout_SFA)); @@ -710,15 +725,15 @@ struct CollectiveMma< static_assert(rank(decltype(gSFB_nkl){}) == 5); // 1 thread copies entire set of scalar - TiledCopyScaleA scale_copy_a{}; - TiledCopyScaleB scale_copy_b{}; + GmemTiledCopySFA scale_copy_a{}; + GmemTiledCopySFB scale_copy_b{}; - ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(_0{}); - ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(_0{}); + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); - Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_scale_A.begin()), + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) - Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_scale_B.begin()), + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) @@ -733,19 +748,18 @@ struct CollectiveMma< static_assert(rank(decltype(tSFAgSFA_mkl){}) == 6); static_assert(rank(decltype(tSFBgSFB_nkl){}) == 6); - LoadParams load_params { - shape<3>(gA_mkl), // for scheduler - tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + LoadSFParams load_params { + size<3>(gSFA_mkl), tSFAgSFA_mkl, tSFBgSFB_nkl, // for input scale tensor values tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, // for predicating scale tensor copies tSFAsSFA, tSFBsSFB, // for scale tensor values - mcast_mask_a, mcast_mask_b, // multicast masks mainloop_params.layout_SFA, // for predicating scale tensor copies mainloop_params.layout_SFB // for predicating scale tensor copies }; return load_params; } + /// Set up the data needed by this collective for mma compute. template CUTLASS_DEVICE auto @@ -756,8 +770,27 @@ struct CollectiveMma< Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // Allocate "fragments/descriptors" for A and B matrices - Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrA_ = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB_ = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(rank(tCrA_) == _4{}); + + auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), + get<1>(shape(tCrA_.layout())), + Int{}, + _1{}); + + auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), + get<1>(shape(tCrB_.layout())), + Int{}, + _1{}); + + Tensor tCrA = flat_divide(tCrA_, + mma_tile_shape_A)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_M,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + + Tensor tCrB = flat_divide(tCrB_, + mma_tile_shape_B)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_N,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); @@ -780,7 +813,7 @@ struct CollectiveMma< /// Set up the data needed by this collective for transform. template CUTLASS_DEVICE auto - transform_init( + accum_init( ProblemShape_MNKL const& problem_shape_MNKL, TensorStorage& shared_tensors) const { using X = Underscore; @@ -788,13 +821,13 @@ struct CollectiveMma< // Separate out problem shape for convenience auto [M,N,K,L] = problem_shape_MNKL; - Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.begin()), + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (ScaleMsPerTile,ScakeKsPerTile,P) - Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.begin()), + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (ScaleNsPerTile,ScaleKsPerTile,P) - TransformParams transform_params { + AccumTransformParams transform_params { sSFA, sSFB // for input tensor values }; return transform_params; @@ -803,34 +836,92 @@ struct CollectiveMma< /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < - class LoadParams, + class LoadABParams, class TileCoordMNKL, class KTileIterator > CUTLASS_DEVICE auto - load( - MainloopPipeline mainloop_pipeline, - Load2TransformPipeline load2transform_pipeline, - MainloopPipelineState mainloop_pipe_producer_state, - Load2TransformPipelineState load2transform_pipe_producer_state, - LoadParams const& load_inputs, + load_ab( + MainloopABPipeline mainloop_pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + LoadABParams const& load_inputs, TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count) { auto [unused_k_tiles, tAgA_mkl, tBgB_nkl, tAsA, tBsB, - tSFAgSFA_mkl, tSFBgSFB_nkl, - tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, - tSFAsSFA, tSFBsSFB, - mcast_mask_a, mcast_mask_b, - layout_SFA, layout_SFB] = load_inputs; + mcast_mask_a, mcast_mask_b] = load_inputs; // slice out the work coord from partitioned tensors Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); - TiledCopyScaleA scale_copy_a{}; - TiledCopyScaleB scale_copy_b{}; + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + auto curr_mainloop_pipe_producer_state = mainloop_pipe_producer_state; + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_ab_tail( + MainloopABPipeline mainloop_pipeline, + MainloopABPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped transform + /// Load producer Perspective + template < + class LoadSFParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + LoadSFParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tSFAgSFA_mkl, tSFBgSFB_nkl, + tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, + tSFAsSFA, tSFBsSFB, + layout_SFA, layout_SFB] = load_inputs; + + // slice out the work coord from partitioned tensors + GmemTiledCopySFA scale_copy_a{}; + GmemTiledCopySFB scale_copy_b{}; Tensor tSFAgSFA = tSFAgSFA_mkl(_, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); @@ -842,69 +933,50 @@ struct CollectiveMma< Tensor thr_tile_pSFB = make_tensor(shape(filter_zeros(thr_tile_SFB_k(_,_,_0{}), tSFBgSFB(_0{},_,_,_0{}).stride()))); - auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); - - // Issue the Mainloop loads + // Issue the loads CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { - // LOCK mainloop_pipe_producer_state for _writing_ - mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); - - load2transform_pipeline.producer_acquire(load2transform_pipe_producer_state); - - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); - - int write_stage = mainloop_pipe_producer_state.index(); - auto curr_mainloop_pipe_producer_state = mainloop_pipe_producer_state; - ++mainloop_pipe_producer_state; - barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + // LOCK pipe_producer_state for _writing_ + mainloop_sf_pipeline.producer_acquire(mainloop_sf_pipe_producer_state); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFA); ++i) { Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); - thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))); + thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); } CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFB); ++i) { Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); - thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))); + thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); } - if (cute::elect_one_sync()) { - copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); - copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); - copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,load2transform_pipe_producer_state.index()))); - copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,load2transform_pipe_producer_state.index()))); - load2transform_pipeline.producer_commit(load2transform_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); - } + copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,mainloop_sf_pipe_producer_state.index()))); + copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,mainloop_sf_pipe_producer_state.index()))); + mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); __syncwarp(); - ++load2transform_pipe_producer_state; + ++mainloop_sf_pipe_producer_state; --k_tile_count; ++k_tile_iter; } - return cute::make_tuple(mainloop_pipe_producer_state, load2transform_pipe_producer_state, k_tile_iter); + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); } /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster CUTLASS_DEVICE void - load_tail( - MainloopPipeline mainloop_pipeline, - Load2TransformPipeline load2transform_pipeline, - MainloopPipelineState mainloop_pipe_producer_state, - Load2TransformPipelineState load2transform_pipe_producer_state) { + load_sf_tail( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state) { // Issue the epilogue waits // This helps avoid early exit of ctas in Cluster // Waits for all stages to either be released (all // Consumer UNLOCKs), or if the stage was never used // then would just be acquired since the phase was // still inverted from make_producer_start_state - mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); - load2transform_pipeline.producer_tail(load2transform_pipe_producer_state); + mainloop_sf_pipeline.producer_tail(mainloop_sf_pipe_producer_state); } /// Perform a collective-scoped matrix multiply-accumulate @@ -916,10 +988,10 @@ struct CollectiveMma< > CUTLASS_DEVICE auto mma( - cute::tuple pipelines, - cute::tuple pipeline_states, + cute::tuple pipelines, + cute::tuple pipeline_states, TmemStorage tmem_storage, MmaParams const& mma_inputs, CtaTileCoord cta_tile_coord, @@ -927,10 +999,10 @@ struct CollectiveMma< auto [tiled_mma, tCrA, tCrB] = mma_inputs; auto [mainloop_pipeline, - mma2transform_pipeline] = pipelines; + accumulator_pipeline] = pipelines; auto [mainloop_pipe_consumer_state, - mma2transform_pipe_producer_state] = pipeline_states; + accumulator_pipe_producer_state] = pipeline_states; uint32_t skip_wait = k_tile_count <= 0; auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); @@ -958,54 +1030,50 @@ struct CollectiveMma< // Peek at next iteration barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - static_assert(size<2>(tCrA) / K_BLOCK_MMAS_PER_SCALE_K, "k blocks must be divisible by K_BLOCK_MMAS_PER_SCALE_K"); - CUTLASS_PRAGMA_UNROLL - for (int scale_k_blocks = 0; scale_k_blocks < size<2>(tCrA) / K_BLOCK_MMAS_PER_SCALE_K; ++scale_k_blocks) { - mma2transform_pipeline.producer_acquire(mma2transform_pipe_producer_state); + for (int scale_k_iter = 0; scale_k_iter < size<3>(tCrA); ++scale_k_iter) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - auto acc = get<0>(slice_accumulator(tmem_storage, mma2transform_pipe_producer_state.index())); + auto acc = get<0>(slice_accumulator(tmem_storage, accumulator_pipe_producer_state.index())); static_assert(is_tmem>::value, "Accumulator must be tmem resident."); static_assert(rank(remove_cvref_t{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); // for each set of scale_k_blocks we zero the accumulator tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; - int start_k_block = scale_k_blocks * size<2>(tCrA) / K_BLOCK_MMAS_PER_SCALE_K; // Unroll the K mode manually so we can set scale C to 1 CUTLASS_PRAGMA_UNROLL - for (int k_block_offset = 0; k_block_offset < K_BLOCK_MMAS_PER_SCALE_K; ++k_block_offset) { - int k_block = start_k_block + k_block_offset; + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, - tCrA(_,_,k_block,read_stage), - tCrB(_,_,k_block,read_stage), + tCrA(_,_,k_block,scale_k_iter,read_stage), + tCrB(_,_,k_block,scale_k_iter,read_stage), acc); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } - mma2transform_pipeline.producer_commit(mma2transform_pipe_producer_state); - ++mma2transform_pipe_producer_state; + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; } mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); } - return make_tuple(mainloop_pipe_consumer_state, mma2transform_pipe_producer_state); + return make_tuple(mainloop_pipe_consumer_state, accumulator_pipe_producer_state); } /// Transform template < - class TransformParams, + class AccumTransformParams, class TmemStorage, class CtaTileCoord, class CopyOpT2R, class EpilogueTile > CUTLASS_DEVICE auto - transform( - cute::tuple pipelines, - cute::tuple consumer_states, + accum( + cute::tuple pipelines, + cute::tuple consumer_states, TmemStorage tmem_storage, - TransformParams const& transform_inputs, + AccumTransformParams const& transform_inputs, CtaTileCoord cta_tile_coord, CopyOpT2R, EpilogueTile, @@ -1076,14 +1144,14 @@ struct CollectiveMma< // Zero our accumulator clear(tTR_FullAcc); - auto [mma2transform_pipeline, load2transform_pipeline] = pipelines; - auto [mma2transform_pipe_state, load2transform_pipe_state] = consumer_states; + auto [accumulator_pipeline, mainloop_sf_pipeline] = pipelines; + auto [accumulator_pipe_state, mainloop_sf_pipe_state] = consumer_states; CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { - load2transform_pipeline.consumer_wait(load2transform_pipe_state); - int read_idx = load2transform_pipe_state.index(); + mainloop_sf_pipeline.consumer_wait(mainloop_sf_pipe_state); + int read_idx = mainloop_sf_pipe_state.index(); copy(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,read_idx)), tTR_rSFA_compact); copy(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,read_idx)), tTR_rSFB_compact); @@ -1094,15 +1162,15 @@ struct CollectiveMma< Tensor tTR_rSFA = make_tensor(tTR_rSFA_compact.data(), tTR_rSFA_layout); Tensor tTR_rSFB = make_tensor(tTR_rSFB_compact.data(), tTR_rSFB_layout); - load2transform_pipeline.consumer_release(load2transform_pipe_state); - ++load2transform_pipe_state; + mainloop_sf_pipeline.consumer_release(mainloop_sf_pipe_state); + ++mainloop_sf_pipe_state; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < ScaleKsPerTile; ++k_block) { - mma2transform_pipeline.consumer_wait(mma2transform_pipe_state); + accumulator_pipeline.consumer_wait(accumulator_pipe_state); - Tensor acc = get<0>(slice_accumulator(tmem_storage, mma2transform_pipe_state.index())); + Tensor acc = get<0>(slice_accumulator(tmem_storage, accumulator_pipe_state.index())); Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) Tensor tTR_tAcc = thread_t2r_epi.partition_S(tAcc_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) @@ -1128,15 +1196,15 @@ struct CollectiveMma< } } cutlass::arch::fence_view_async_tmem_load(); - mma2transform_pipeline.consumer_release(mma2transform_pipe_state); + accumulator_pipeline.consumer_release(accumulator_pipe_state); // release acc - ++mma2transform_pipe_state; + ++accumulator_pipe_state; } --k_tile_count; } - return cute::make_tuple(tTR_FullAcc, tiled_t2r_epi, cute::make_tuple(mma2transform_pipe_state, load2transform_pipe_state)); + return cute::make_tuple(tTR_FullAcc, tiled_t2r_epi, cute::make_tuple(accumulator_pipe_state, mainloop_sf_pipe_state)); } protected: diff --git a/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp index b970f95b..5ba16b41 100644 --- a/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp @@ -866,6 +866,11 @@ struct CollectiveMma< // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + if constexpr (not IsOverlappingAccum) { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { // WAIT on mainloop_pipe_consumer_state until its data are available @@ -884,15 +889,23 @@ struct CollectiveMma< // Peek at next iteration barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - if (iter % UtccpReuseCnt == 0) { + if constexpr (UtccpReuseCnt == 1) { if (cute::elect_one_sync()) { copy(tiled_copy_s2t_E, thr_tCsE_s2t(_,_,_,_,read_stage), thr_tCtE_s2t); } } + else { + if (not (iter & 1)) { + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_E, thr_tCsE_s2t(_,_,_,_,read_stage), thr_tCtE_s2t); + } + } + } - // Wait for tmem accumulator buffer to become empty with a flipped phase - if (iter == 0) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + if constexpr (IsOverlappingAccum) { + if (iter == 0) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } } // Unroll the K mode manually so we can set scale C to 1 diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index cc61dc8b..7c829ea1 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -475,6 +475,15 @@ struct KernelTmaWarpSpecializedMmaTransformSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelPtrArrayTmaWarpSpecializedMmaTransformSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // Sparse Gemm template< int SchedulerPipelineStageCount_, @@ -602,12 +611,16 @@ struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {}; struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {}; struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {}; - /////////////////////////////////////////////////////////////////////////////////////////////////////// -// SM100 Blockwise GEMM Dispatch Policies +// SM100 Blockwise GEMM + Ptr-Array GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// struct KernelScheduleSm100Blockwise : KernelScheduleSm100 {}; struct KernelTmaWarpSpecializedBlockwise1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100Blockwise {}; +struct KernelTmaWarpSpecializedBlockwise2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100Blockwise {}; + +struct KernelScheduleSm100PtrArrayBlockwise : KernelScheduleSm100Blockwise {}; +struct KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayBlockwise {}; +struct KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayBlockwise {}; /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Planar Complex GEMM Dispatch Policies @@ -728,14 +741,13 @@ struct KernelScheduleF8f6f4Sm120 final : KernelScheduleSm120DenseGemm {}; struct KernelScheduleBlockScaledGemmSm120 : KernelScheduleSm120 {}; struct KernelScheduleMxf8f6f4Sm120 : KernelScheduleBlockScaledGemmSm120 {}; struct KernelScheduleMxNvf4Sm120 : KernelScheduleBlockScaledGemmSm120 {}; -// Block Scaled Sparse GEMM: Specialize for instruction type, scale factor vector size. +// Block Scaled GEMM: Specialize for instruction type, scale factor vector size. struct KernelTmaWarpSpecializedNvf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedCooperative { }; struct KernelTmaWarpSpecializedPingpongNvf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedMxf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedCooperative { }; struct KernelTmaWarpSpecializedPingpongMxf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedMxf8f6f4Sm120 final : KernelScheduleMxf8f6f4Sm120, KernelTmaWarpSpecializedCooperative { }; struct KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120 final : KernelScheduleMxf8f6f4Sm120, KernelTmaWarpSpecializedPingpong { }; - /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM120 Sparse GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -786,6 +798,21 @@ struct MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling { constexpr static bool IsOverlappingAccum = false; }; +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelPtrArrayTmaWarpSpecializedMmaTransformSm100; + constexpr static bool IsOverlappingAccum = false; +}; + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 38fab69f..08605e00 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -68,6 +68,7 @@ struct IsCutlass3ArrayKernel +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> { +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using LayoutSFA = typename cutlass::detail::LayoutSFAType::type; + using LayoutSFB = typename cutlass::detail::LayoutSFBType::type; + using ElementSF = typename cutlass::detail::ElementSFType::type; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + static_assert(!IsOverlappingAccum, "Does not support overlapping accumulator"); + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + using TileSchedulerTag = cute::conditional_t; + + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount, ProblemShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopABLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopSFLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + + static constexpr uint32_t MaxThreadsPerBlock = cute::round_up(NumSchedThreads + + NumMainloopABLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + + NumMainloopSFLoadThreads, 128); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipeline and pipeline state types + using MainloopABPipeline = typename CollectiveMainloop::MainloopABPipeline; + using MainloopABPipelineState = typename CollectiveMainloop::MainloopABPipelineState; + + using MainloopSFPipeline = typename CollectiveMainloop::MainloopSFPipeline; + using MainloopSFPipelineState = typename CollectiveMainloop::MainloopSFPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = typename CollectiveMainloop::AccumulatorPipeline; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + static constexpr uint32_t GenericRegisterRequirement = 48; + static constexpr uint32_t AccumRegisterRequirement = 256; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorMapStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + alignas(128) EpilogueTensorMapStorage epilogue; + alignas(128) MainloopTensorMapStorage mainloop; + } tensormaps; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopABLoad = 2, + EpilogueLoad = 3, + Epilogue = 4, // 4 warps + MainloopSFLoad = 8, + Unused = 9, + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_ab_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_sf_load = false; + uint32_t unused = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + constexpr uint32_t NumEpilogueSubTiles = 1; + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + ProblemShape problem_shapes = args.problem_shape; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (!IsGroupedGemmKernel && sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace); + } + else { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ); + } + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + scheduler, + args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + if constexpr (IsGroupedGemmKernel) { + // Group GEMM currently only supports rank-3 problem shapes + implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); + } else { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Mainloop, Epilogue or Scheduler don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Dynamic Cluster or Preferred Cluster don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + constexpr bool IsBlockscaled = !cute::is_void_v; + if constexpr (IsBlockscaled) { + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + } + else { + // Special cluster check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + } + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Mainloop + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + constexpr uint32_t NumEpilogueSubTiles = 1; + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Mainloop + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // NOTE: cluster_shape here is the major cluster shape, not fallback one + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + else { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape.get_host_problem_shape(), + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + return grid_shape; + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + auto problem_shape = params.problem_shape; + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = [&] () CUTLASS_LAMBDA_FUNC_INLINE { + if (warp_idx < static_cast(WarpCategory::Epilogue)) { + return WarpCategory(warp_idx); + } else if (warp_idx < static_cast(WarpCategory::MainloopSFLoad)) { + return WarpCategory::Epilogue; + } else if (warp_idx == static_cast(WarpCategory::MainloopSFLoad)) { + return WarpCategory::MainloopSFLoad; + } else { + return WarpCategory::Unused; + } + }(); + + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopABLoad), // main_ab_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopSFLoad), // main_sf_load + (warp_category == WarpCategory::Unused) // unused + }; + + // Mainloop Load pipeline + typename MainloopABPipeline::Params mainloop_ab_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; + } + mainloop_ab_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_ab_load; + mainloop_ab_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_ab_pipeline_params.initializing_warp = 0; + MainloopABPipeline mainloop_ab_pipeline(shared_storage.pipelines.mainloop.pipeline_ab, + mainloop_ab_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + typename MainloopSFPipeline::Params mainloop_sf_pipeline_params; + if (WarpCategory::MainloopSFLoad == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Consumer; + } + mainloop_sf_pipeline_params.initializing_warp = 8; + mainloop_sf_pipeline_params.producer_arv_count = CollectiveMainloop::NumMainloopSFProducerThreadEvents; + mainloop_sf_pipeline_params.consumer_arv_count = NumEpilogueThreads; + + MainloopSFPipeline mainloop_sf_pipeline(shared_storage.pipelines.mainloop.pipeline_sf, + mainloop_sf_pipeline_params); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopABLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopABLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopABLoadThreads + NumEpilogueThreads + + NumMainloopSFLoadThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.mainloop.pipeline_accum, + accumulator_pipeline_params, + cluster_shape); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopABLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopABLoadThreads + + (is_epi_load_needed ? NumEpilogueLoadThreads : 0)); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + MainloopABPipelineState mainloop_ab_pipe_consumer_state; + MainloopABPipelineState mainloop_ab_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + MainloopSFPipelineState mainloop_sf_pipe_consumer_state; + MainloopSFPipelineState mainloop_sf_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + int32_t sm_id = static_cast(cutlass::arch::SmId()); + + // Calculate mask after cluster barrier arrival + mainloop_ab_pipeline.init_masks(cluster_shape, block_id_in_cluster); + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + TiledMma tiled_mma; + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + + pipeline_init_wait(cluster_size); + + if constexpr (IsGroupedGemmKernel) { + if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups + return; + } + // In case user wants to engage less SMs than available on device + sm_id = blockIdx.x + (blockIdx.y * gridDim.x); + } + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + + if (is_participant.main_ab_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_ab_init( + problem_shape_MNKL, params.mainloop, + shared_storage.tensors.mainloop, + shared_storage.tensormaps.mainloop, + params.hw_info.sm_count, sm_id, work_tile_info.L_idx); + Tensor gA_mkl = get<0>(load_inputs); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = get(load_inputs); + + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool did_batch_change = true; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + do { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + if (did_batch_change) { + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape, + curr_batch + ); + } + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(MainloopABPipeline::Stages, k_tile_count); + + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_ab_producer_state_next, k_tile_iter_next] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_prologue, + did_batch_change + ); + mainloop_ab_pipe_producer_state = mainloop_ab_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_ab_producer_state_next_, unused_] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter_next, k_tile_count - k_tile_prologue, + false /* did_batch_change - prologue loads handle tensormap acquire */ + ); + mainloop_ab_pipe_producer_state = mainloop_ab_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); + } while (work_tile_info.is_valid()); + collective_mainloop.load_ab_tail(mainloop_ab_pipeline, mainloop_ab_pipe_producer_state); + + } + + else if (is_participant.main_sf_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + int32_t curr_batch = idx2crd(work_tile_info.L_idx, get<3>(problem_shape_MNKL)); // Usually just returns work_tile_info.L_idx; + + auto mainloop_sf_inputs = collective_mainloop.load_sf_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, curr_batch); + + Tensor gA_mkl = get<0>(mainloop_sf_inputs); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool requires_clc_query = true; + bool did_batch_change = true; + + do { + + int32_t curr_batch = idx2crd(work_tile_info.L_idx, size<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + if (did_batch_change) { + mainloop_sf_inputs = collective_mainloop.load_sf_update( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, curr_batch); + } + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the loads + // we are managingo an array of pointers to change batches, we need to neglect the L mode + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_sf_producer_state_next, k_tile_iter_next] = collective_mainloop.load_sf( + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + mainloop_sf_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_count + ); + mainloop_sf_pipe_producer_state = mainloop_sf_producer_state_next; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, size<4>(gA_mkl)); + } while (work_tile_info.is_valid()); + + collective_mainloop.load_sf_tail( + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state + ); + + } + + else if (is_participant.sched) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + // Grouped GEMM uses static tile scheduler + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.mma) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + int tmem_non_accumulator_base = tmem_base_ptr + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + + + auto mma_inputs = collective_mainloop.mma_init( + params.mainloop, + collective_mainloop.slice_accumulator(accumulators, 0), + shared_storage.tensors.mainloop, + tmem_non_accumulator_base /*Start SF TMEM allocation after the accumulator*/); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + } + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + if (is_mma_leader_cta) { + auto [mainloop_ab_pipe_consumer_state_next, accumulator_pipe_producer_state_next] = collective_mainloop.mma( + cute::make_tuple( + mainloop_ab_pipeline, accumulator_pipeline), + cute::make_tuple( + mainloop_ab_pipe_consumer_state, accumulator_pipe_producer_state), + accumulators, + mma_inputs, + cta_coord_mnkl, + k_tile_count); + mainloop_ab_pipe_consumer_state = mainloop_ab_pipe_consumer_state_next; + accumulator_pipe_producer_state = accumulator_pipe_producer_state_next; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + + // Fetch a copy of tensormaps for the CTA from Params + auto epi_load_tensormap = get<0>(collective_epilogue.load_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape, + curr_batch + ); + } + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.template load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + cute::make_tuple(epi_load_tensormap, did_batch_change), + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Register reconfiguration + arch::warpgroup_reg_alloc(); + + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + + auto accum_inputs = collective_mainloop.accum_init(shared_storage.tensors.mainloop); + + auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); + bool do_tail_store = false; + // Fetch a copy of tensormaps for the CTA from Params + auto epi_store_tensormap = get<0>(collective_epilogue.store_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + + auto pipelines = cute::make_tuple(accumulator_pipeline, mainloop_sf_pipeline); + auto states = cute::make_tuple(accumulator_pipe_consumer_state, mainloop_sf_pipe_consumer_state); + + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape, + curr_batch + ); + } + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Fusions may need problem shape for the current group + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + + // Get accumulator + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + auto [accum, tiled_t2r, next_state] = collective_mainloop.accum( + pipelines, + states, + accumulators, + accum_inputs, + cta_coord_mnkl, + typename CollectiveEpilogue::CopyOpT2R{}, + typename CollectiveEpilogue::EpilogueTile{}, + k_tile_count + ); + + states = next_state; + + // + // Epilogue and write to gD + // + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); + } + auto [load_state_next, store_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accum, + shared_storage.tensors.epilogue, + epi_store_tensormap, + tiled_t2r // tiled_t2r + ); + + do_tail_store |= TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + else { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp index abf8ba3f..9aec1636 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp @@ -131,16 +131,19 @@ public: static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; // Warp specialization thread count per threadblock - static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; - static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopABLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumMainloopSFLoadThreads = NumThreadsPerWarp; // 1 warp - static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + - NumMainloopLoadThreads + NumMMAThreads + - NumEpilogueLoadThreads + NumEpilogueThreads; + + static constexpr uint32_t MaxThreadsPerBlock = cute::round_up(NumSchedThreads + + NumMainloopABLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + + NumMainloopSFLoadThreads, 128); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_load_pipe_increment(CtaShape_MNK{}); @@ -152,8 +155,8 @@ public: static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); // Pipeline and pipeline state types - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + using MainloopABPipeline = typename CollectiveMainloop::MainloopABPipeline; + using MainloopABPipelineState = typename CollectiveMainloop::MainloopABPipelineState; using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; @@ -163,11 +166,11 @@ public: using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; - using Mma2TransformPipeline = typename CollectiveMainloop::Mma2TransformPipeline; - using Mma2TransformPipelineState = typename Mma2TransformPipeline::PipelineState; + using AccumulatorPipeline = typename CollectiveMainloop::AccumulatorPipeline; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; - using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; - using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + using MainloopSFPipeline = typename CollectiveMainloop::MainloopSFPipeline; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; using CLCPipeline = cutlass::PipelineCLCFetchAsync; using CLCPipelineState = typename CLCPipeline::PipelineState; @@ -178,7 +181,7 @@ public: using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; - static constexpr uint32_t GenericRegisterRequirement = 104; + static constexpr uint32_t GenericRegisterRequirement = 48; static constexpr uint32_t AccumRegisterRequirement = 256; // Kernel level shared memory storage @@ -186,19 +189,15 @@ public: // Barriers should be allocated in lower 8KB of SMEM for SM100 struct PipelineStorage : cute::aligned_struct<16, _1> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using Load2TransformPipelineStorage = typename CollectiveMainloop::Load2TransformPipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; using CLCPipelineStorage = typename CLCPipeline::SharedStorage; - using Mma2TransformPipelineStorage = typename CollectiveMainloop::Mma2TransformPipelineStorage; using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; alignas(16) MainloopPipelineStorage mainloop; - alignas(16) Load2TransformPipelineStorage load2transform; alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) LoadOrderBarrierStorage load_order; alignas(16) CLCPipelineStorage clc; - alignas(16) Mma2TransformPipelineStorage mma2transform; alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; alignas(16) arch::ClusterBarrier epilogue_throttle; @@ -240,19 +239,23 @@ public: }; enum class WarpCategory : int32_t { - MMA = 0, - Sched = 1, - MainloopLoad = 2, - EpilogueLoad = 3, - Epilogue = 4 + MMA = 0, + Sched = 1, + MainloopABLoad = 2, + EpilogueLoad = 3, + Epilogue = 4, // 4 warps + MainloopSFLoad = 8, + Unused = 9, }; struct IsParticipant { - uint32_t mma = false; - uint32_t sched = false; - uint32_t main_load = false; - uint32_t epi_load = false; - uint32_t epilogue = false; + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_ab_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_sf_load = false; + uint32_t unused = false; }; // @@ -407,8 +410,20 @@ public: // Account for more than one epilogue warp int warp_idx = canonical_warp_idx_sync(); - WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) - : WarpCategory::Epilogue; + WarpCategory warp_category = [&] () CUTLASS_LAMBDA_FUNC_INLINE { + if (warp_idx < static_cast(WarpCategory::Epilogue)) { + return WarpCategory(warp_idx); + } + else if (warp_idx < static_cast(WarpCategory::MainloopSFLoad)) { + return WarpCategory::Epilogue; + } + else if (warp_idx == static_cast(WarpCategory::MainloopSFLoad)) { + return WarpCategory::MainloopSFLoad; + } + else { + return WarpCategory::Unused; + } + }(); uint32_t lane_predicate = cute::elect_one_sync(); auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); @@ -440,41 +455,43 @@ public: IsParticipant is_participant = { (warp_category == WarpCategory::MMA), // mma (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched - (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::MainloopABLoad), // main_ab_load (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load - (warp_category == WarpCategory::Epilogue) // epilogue + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopSFLoad), // main_sf_load + (warp_category == WarpCategory::Unused) // unused }; // Mainloop Load pipeline - typename MainloopPipeline::Params mainloop_pipeline_params; - if (WarpCategory::MainloopLoad == warp_category) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + typename MainloopABPipeline::Params mainloop_ab_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; } if (WarpCategory::MMA == warp_category) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; } - mainloop_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load; - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - mainloop_pipeline_params.initializing_warp = 0; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, - mainloop_pipeline_params, - cluster_shape, - cute::true_type{}, // Perform barrier init - cute::false_type{}); // Delay mask calculation + mainloop_ab_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_ab_load; + mainloop_ab_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_ab_pipeline_params.initializing_warp = 0; + MainloopABPipeline mainloop_ab_pipeline(shared_storage.pipelines.mainloop.pipeline_ab, + mainloop_ab_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation - typename Load2TransformPipeline::Params load2transform_pipeline_params; - if (WarpCategory::MainloopLoad == warp_category) { - load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Producer; + typename MainloopSFPipeline::Params mainloop_sf_pipeline_params; + if (WarpCategory::MainloopSFLoad == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Producer; } if (WarpCategory::Epilogue == warp_category) { - load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Consumer; + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Consumer; } - load2transform_pipeline_params.initializing_warp = 0; - load2transform_pipeline_params.producer_arv_count = CollectiveMainloop::NumLoad2TransformProducerThreadEvents; - load2transform_pipeline_params.consumer_arv_count = NumEpilogueThreads; + mainloop_sf_pipeline_params.initializing_warp = 8; + mainloop_sf_pipeline_params.producer_arv_count = CollectiveMainloop::NumMainloopSFProducerThreadEvents; + mainloop_sf_pipeline_params.consumer_arv_count = NumEpilogueThreads; - Load2TransformPipeline load2transform_pipeline(shared_storage.pipelines.load2transform, - load2transform_pipeline_params); + MainloopSFPipeline mainloop_sf_pipeline(shared_storage.pipelines.mainloop.pipeline_sf, + mainloop_sf_pipeline_params); // Epilogue Load pipeline typename EpiLoadPipeline::Params epi_load_pipeline_params; @@ -498,8 +515,8 @@ public: // Load order barrier typename LoadOrderBarrier::Params load_order_barrier_params; - load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; - load_order_barrier_params.group_size = NumMainloopLoadThreads; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopABLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopABLoadThreads; load_order_barrier_params.initializing_warp = 5; LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); @@ -514,7 +531,8 @@ public: clc_pipeline_params.producer_blockid = 0; clc_pipeline_params.producer_arv_count = 1; clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * - (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + (NumMainloopABLoadThreads + NumEpilogueThreads + + NumMMAThreads + NumMainloopSFLoadThreads); if (is_epi_load_needed) { clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; } @@ -523,30 +541,30 @@ public: CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); // Mainloop-Epilogue pipeline - typename Mma2TransformPipeline::Params mma2transform_pipeline_params; + typename AccumulatorPipeline::Params accumulator_pipeline_params; if (WarpCategory::MMA == warp_category) { - mma2transform_pipeline_params.role = Mma2TransformPipeline::ThreadCategory::Producer; + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; } if (WarpCategory::Epilogue == warp_category) { - mma2transform_pipeline_params.role = Mma2TransformPipeline::ThreadCategory::Consumer; + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; } // Only one producer thread arrives on this barrier. - mma2transform_pipeline_params.producer_arv_count = 1; - mma2transform_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; - mma2transform_pipeline_params.initializing_warp = 2; - Mma2TransformPipeline mma2transform_pipeline(shared_storage.pipelines.mma2transform, - mma2transform_pipeline_params, + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.mainloop.pipeline_accum, + accumulator_pipeline_params, cluster_shape); // CLC throttle pipeline typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; - if (WarpCategory::MainloopLoad == warp_category) { + if (WarpCategory::MainloopABLoad == warp_category) { clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; } if (WarpCategory::Sched == warp_category) { clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; } - clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.producer_arv_count = NumMainloopABLoadThreads; clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; clc_throttle_pipeline_params.dst_blockid = 0; clc_throttle_pipeline_params.initializing_warp = 3; @@ -573,7 +591,7 @@ public: if (WarpCategory::MMA == warp_category && lane_predicate) { epilogue_throttle_barrier.init( NumMMAThreads + (is_first_cta_in_cluster ? NumSchedThreads : 0) + - NumMainloopLoadThreads + + NumMainloopABLoadThreads + (is_epi_load_needed ? NumEpilogueLoadThreads : 0)); } @@ -581,11 +599,11 @@ public: // To all producers and consumer threadblocks in the cluster pipeline_init_arrive_relaxed(cluster_size); - auto load_inputs = collective_mainloop.load_init( + auto load_inputs = collective_mainloop.load_ab_init( problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); - MainloopPipelineState mainloop_pipe_consumer_state; - MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + MainloopABPipelineState mainloop_ab_pipe_consumer_state; + MainloopABPipelineState mainloop_ab_pipe_producer_state = cutlass::make_producer_start_state(); EpiLoadPipelineState epi_load_pipe_consumer_state; EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); @@ -596,17 +614,17 @@ public: CLCPipelineState clc_pipe_consumer_state; CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); - Mma2TransformPipelineState mma2transform_pipe_consumer_state; - Mma2TransformPipelineState mma2transform_pipe_producer_state = cutlass::make_producer_start_state(); + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); - Load2TransformPipelineState load2transform_pipe_consumer_state; - Load2TransformPipelineState load2transform_pipe_producer_state = cutlass::make_producer_start_state(); + MainloopSFPipelineState mainloop_sf_pipe_consumer_state; + MainloopSFPipelineState mainloop_sf_pipe_producer_state = cutlass::make_producer_start_state(); dim3 block_id_in_cluster = cute::block_id_in_cluster(); // Calculate mask after cluster barrier arrival - mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); - mma2transform_pipeline.init_masks(cluster_shape, block_id_in_cluster); + mainloop_ab_pipeline.init_masks(cluster_shape, block_id_in_cluster); + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); // TileID scheduler TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); @@ -619,7 +637,7 @@ public: pipeline_init_wait(cluster_size); - if (is_participant.main_load) { + if (is_participant.main_ab_load) { // Register reconfiguration arch::warpgroup_reg_dealloc(); @@ -633,15 +651,12 @@ public: epilogue_throttle_barrier.arrive(); bool requires_clc_query = true; - auto pipelines = cute::make_tuple(mainloop_pipeline, load2transform_pipeline); - auto states = cute::make_tuple(mainloop_pipe_producer_state, load2transform_pipe_producer_state); - do { // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, load_inputs.k_tiles); auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); - auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + auto k_tile_prologue = min(MainloopABPipeline::Stages, k_tile_count); if constexpr (IsSchedDynamicPersistent) { if (is_first_cta_in_cluster && requires_clc_query) { @@ -652,34 +667,28 @@ public: } // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads - auto [mainloop_producer_state_next, load2transform_producer_state_next, k_tile_iter_next] = collective_mainloop.load( - mainloop_pipeline, - load2transform_pipeline, - mainloop_pipe_producer_state, - load2transform_pipe_producer_state, + auto [mainloop_ab_producer_state_next, k_tile_iter_next] = collective_mainloop.load_ab( + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, load_inputs, cta_coord_mnkl, k_tile_iter, k_tile_prologue ); - mainloop_pipe_producer_state = mainloop_producer_state_next; - load2transform_pipe_producer_state = load2transform_producer_state_next; + mainloop_ab_pipe_producer_state = mainloop_ab_producer_state_next; if (do_load_order_arrive) { load_order_barrier.arrive(); do_load_order_arrive = false; } - auto [mainloop_producer_state_next_, load2transform_producer_state_next_, unused_] = collective_mainloop.load( - mainloop_pipeline, - load2transform_pipeline, - mainloop_pipe_producer_state, - load2transform_pipe_producer_state, + auto [mainloop_ab_producer_state_next_, unused_] = collective_mainloop.load_ab( + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, load_inputs, cta_coord_mnkl, k_tile_iter_next, k_tile_count - k_tile_prologue ); - mainloop_pipe_producer_state = mainloop_producer_state_next_; - load2transform_pipe_producer_state = load2transform_producer_state_next_; + mainloop_ab_pipe_producer_state = mainloop_ab_producer_state_next_; // Sync warp to prevent non-participating threads entering next wave early __syncwarp(); @@ -697,11 +706,61 @@ public: } } while (work_tile_info.is_valid()); - collective_mainloop.load_tail( - mainloop_pipeline, - load2transform_pipeline, - mainloop_pipe_producer_state, - load2transform_pipe_producer_state + collective_mainloop.load_ab_tail( + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state + ); + + } + + else if (is_participant.main_sf_load) { + auto mainloop_sf_inputs = collective_mainloop.load_sf_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool requires_clc_query = true; + + do { + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, mainloop_sf_inputs.k_tiles); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_sf_producer_state_next, k_tile_iter_next] = collective_mainloop.load_sf( + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + mainloop_sf_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count + ); + mainloop_sf_pipe_producer_state = mainloop_sf_producer_state_next; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + + collective_mainloop.load_sf_tail( + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state ); } @@ -791,16 +850,16 @@ public: } if (is_mma_leader_cta) { - auto [mainloop_pipe_consumer_state_, mma2transform_pipe_producer_state_] = collective_mainloop.mma( - cute::make_tuple(mainloop_pipeline, mma2transform_pipeline), - cute::make_tuple(mainloop_pipe_consumer_state, mma2transform_pipe_producer_state), + auto [mainloop_ab_pipe_consumer_state_, accumulator_pipe_producer_state_] = collective_mainloop.mma( + cute::make_tuple(mainloop_ab_pipeline, accumulator_pipeline), + cute::make_tuple(mainloop_ab_pipe_consumer_state, accumulator_pipe_producer_state), tmem_storage, mma_inputs, cta_coord_mnkl, k_tile_count ); - mainloop_pipe_consumer_state = mainloop_pipe_consumer_state_; - mma2transform_pipe_producer_state = mma2transform_pipe_producer_state_; + mainloop_ab_pipe_consumer_state = mainloop_ab_pipe_consumer_state_; + accumulator_pipe_producer_state = accumulator_pipe_producer_state_; } work_tile_info = next_work_tile_info; @@ -817,7 +876,7 @@ public: // Leader MMA waits for leader + peer epilogues to release stage if (is_mma_leader_cta) { - mma2transform_pipeline.producer_tail(mma2transform_pipe_producer_state); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); } // Signal to peer MMA that entire tmem allocation can be deallocated if constexpr (has_mma_peer_cta) { @@ -912,13 +971,13 @@ public: uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); - auto transform_inputs = collective_mainloop.transform_init( + auto accum_inputs = collective_mainloop.accum_init( problem_shape_MNKL, shared_storage.tensors.mainloop ); - auto pipelines = cute::make_tuple(mma2transform_pipeline, load2transform_pipeline); - auto states = cute::make_tuple(mma2transform_pipe_consumer_state, load2transform_pipe_consumer_state); + auto pipelines = cute::make_tuple(accumulator_pipeline, mainloop_sf_pipeline); + auto states = cute::make_tuple(accumulator_pipe_consumer_state, mainloop_sf_pipe_consumer_state); bool do_tail_store = false; do { @@ -935,11 +994,11 @@ public: ++clc_pipe_consumer_state; } - auto [accum, tiled_t2r, next_state] = collective_mainloop.transform( + auto [accum, tiled_t2r, next_state] = collective_mainloop.accum( pipelines, states, tmem_storage, - transform_inputs, + accum_inputs, cta_coord_mnkl, typename CollectiveEpilogue::CopyOpT2R{}, typename CollectiveEpilogue::EpilogueTile{}, diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp index 6698cb09..6ceba52b 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp @@ -405,7 +405,7 @@ public: return make_coord(m_coord, n_coord, _, l_coord); } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static void issue_clc_query(PipelineState state, uint32_t mbarrier_addr, CLCResponse* clc_response_ptr) { #if defined(CUTLASS_ARCH_CLC_ENABLED) @@ -468,7 +468,7 @@ public: // Kernel helper function to get next work tile template - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE auto fetch_next_work( WorkTileInfo work_tile_info, @@ -627,9 +627,10 @@ public: store_query_response(state, make_invalid_response()); } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE void store_query_response(PipelineState state, CLCResponse clc_response) { + #if defined(__CUDA_ARCH__) uint32_t smem_ptr = cute::cast_smem_ptr_to_uint(&clc_response_ptr_[state.index()]); asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" : : "r"(smem_ptr) @@ -638,6 +639,7 @@ public: , "r"(clc_response.data[2]) , "r"(clc_response.data[3])); cutlass::arch::fence_view_async_shared(); + #endif } CUTLASS_DEVICE diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index de9ee5fa..b79a3d25 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -86,8 +86,7 @@ template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; template <> struct has_negative_zero : CUTE_STL_NAMESPACE::true_type{}; - -// Helper variable template +// Helper variable template template inline constexpr bool has_negative_zero_v = has_negative_zero::value; @@ -109,3 +108,6 @@ struct get_unpacked_element_type { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// + + + diff --git a/include/cutlass/pipeline/sm100_pipeline.hpp b/include/cutlass/pipeline/sm100_pipeline.hpp index 7f324d31..4ebd8b5d 100644 --- a/include/cutlass/pipeline/sm100_pipeline.hpp +++ b/include/cutlass/pipeline/sm100_pipeline.hpp @@ -985,7 +985,7 @@ public: consumer_release(state.index()); } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE uint32_t producer_get_barrier(PipelineState state) { return cute::cast_smem_ptr_to_uint(reinterpret_cast(&full_barrier_ptr_[state.index()])); } diff --git a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp index 23ad9673..577c68c3 100644 --- a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +++ b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp @@ -411,7 +411,7 @@ private: CUTE_UNROLL for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) { ElementAMmaRawUnit elem_A = tAsA[elt_log_idx]; - + // Handle negative 0 ElementAMmaRawUnit masked_elem_A = elem_A; if constexpr (has_negative_zero_v) { @@ -506,10 +506,8 @@ private: constexpr bool IsRowMajor = cute::is_same_v; using Element = typename TensorSrc::element_type; - constexpr bool IsQmmaF6 = cute::sizeof_bits_v == 6; - CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dSrc) needs to be static"); CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dDst) needs to be static"); CUTE_STATIC_ASSERT(cute::sizeof_bits_v == cute::sizeof_bits_v, @@ -557,7 +555,6 @@ private: for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; - if constexpr ( (not pred) and (not IsQmmaF6) ) { dDst(row_i, col_i) = dSrc(row_i, col_i); } diff --git a/include/cutlass/version.h b/include/cutlass/version.h index 1e2b5de9..4514330c 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -35,7 +35,7 @@ #include #define CUTLASS_MAJOR 3 -#define CUTLASS_MINOR 8 +#define CUTLASS_MINOR 9 #define CUTLASS_PATCH 0 #ifdef CUTLASS_VERSIONS_GENERATED diff --git a/media/docs/blackwell_cluster_launch_control.md b/media/docs/cpp/blackwell_cluster_launch_control.md similarity index 90% rename from media/docs/blackwell_cluster_launch_control.md rename to media/docs/cpp/blackwell_cluster_launch_control.md index faebb900..bdebed1d 100644 --- a/media/docs/blackwell_cluster_launch_control.md +++ b/media/docs/cpp/blackwell_cluster_launch_control.md @@ -76,8 +76,8 @@ __device__ clc_dynamic_persistent_kernel(...) { ### Cluster Launch Control Pipeline Class -Please refer to the `PipelineCLCFetchAsync` pipeline class defined in [Cluster launch control pipeline class](/include/cutlass/pipeline/sm100_pipeline.hpp). Cluster launch control queries can be pipelined and mananged by an asynchronous pipeline with producer-consumer relationship (See -[pipeline](/media/docs/pipeline.md) document). The producer is the scheduler warp of the 0th CTA in the cluster and the consumers are all warps that need `ClcID`s. +Please refer to the `PipelineCLCFetchAsync` pipeline class defined in [Cluster launch control pipeline class](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp). Cluster launch control queries can be pipelined and mananged by an asynchronous pipeline with producer-consumer relationship (See +[pipeline](pipeline.md) document). The producer is the scheduler warp of the 0th CTA in the cluster and the consumers are all warps that need `ClcID`s. To setup a CLC pipeline correctly, we need to make sure the params are set to the right values: @@ -88,18 +88,18 @@ To setup a CLC pipeline correctly, we need to make sure the params are set to th ### Dynamic tile scheduler class -Please refer to `PersistentTileSchedulerSm100` class defined in [sm100 dynamic persistent tile scheduler](/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp). +Please refer to `PersistentTileSchedulerSm100` class defined in [sm100 dynamic persistent tile scheduler](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp). There are two important methods of the CLC scheduler class. The first is `advance_to_next_work`, which is intended to be executed by one elected thread from the scheduler warp. It effectively sends out the CLC query to the CLC. A CLC query response will be broadcast to the same shared memory address of all CTAs in the cluster. The other method is named `get_current_work`. It simply loads the CLC response from the shared memory buffer indexed by a pipeline state. -The CLC pipeline and scheduler classes are used together to ensure correct functionality and necessary synchronization of CLC feature. Please refer to [cluster launch control pipeline unit test](/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu). +The CLC pipeline and scheduler classes are used together to ensure correct functionality and necessary synchronization of CLC feature. Please refer to [cluster launch control pipeline unit test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu). ## Blackwell Warp-specialized Persistent Kernel -Now, let's take a look at how CLC feature is used in our [Blackwell dense GEMM kernel](/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp). +Now, let's take a look at how CLC feature is used in our [Blackwell dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp). This particular warp-specialized kernel has the following warp assignment: diff --git a/media/docs/blackwell_functionality.md b/media/docs/cpp/blackwell_functionality.md similarity index 72% rename from media/docs/blackwell_functionality.md rename to media/docs/cpp/blackwell_functionality.md index f5d51bae..e751a124 100644 --- a/media/docs/blackwell_functionality.md +++ b/media/docs/cpp/blackwell_functionality.md @@ -77,10 +77,10 @@ All four layouts (TT, NN, NT, TT) are supported for all legacy data types. | | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test | |-------------------------------|------------|------------|----------------|-------------|-------------|-------------------------|-----------| |1 | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | | -|2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)| -|3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)| -|4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)| -|5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)| +|2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)| +|3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)| +|4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)| +|5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)| For narrow precision Mmas, not all A/B type, and A/B layout combinations are supported by every `tcgen05.mma` instructions. Furthermore, tensor copy instructions for subbyte types impose additional alignment requirements while loading narrow-precision @@ -93,31 +93,31 @@ instructions supported by CUTLASS. **Table 2: Valid Data Type, Alignment, and Layout Combinations For Narrow Precision MMAs Without Block Scaling** | | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test | |-------------------------------|----------|----------|----------------|-------------|-------------|-------------------------|-----------| -|[1](#nonbs_rows_1_2_3_6) | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[2](#nonbs_rows_1_2_3_6) | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[3](#nonbs_rows_1_2_3_6) | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[4](#nonbs_rows_4_7) | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | -|[5](#nonbs_rows_5_8) | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | -|[6](#nonbs_rows_1_2_3_6) | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[7](#nonbs_rows_4_7) | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | -|[8](#nonbs_rows_5_8) | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | -|[9](#nonbs_rows_9) | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)| +|[1](#nonbs_rows_1_2_3_6) | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[2](#nonbs_rows_1_2_3_6) | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[3](#nonbs_rows_1_2_3_6) | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[4](#nonbs_rows_4_7) | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | +|[5](#nonbs_rows_5_8) | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | +|[6](#nonbs_rows_1_2_3_6) | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[7](#nonbs_rows_4_7) | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | +|[8](#nonbs_rows_5_8) | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | +|[9](#nonbs_rows_9) | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)| **Table 3: Valid Data Type, Alignment, and Layout Combinations for Block Scaled Narrow Precision MMAs** | | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test| |-------------------------|-------------|-------------|----------------|-------------|-------------|-------------------------|------| -|[1](#bs_rows_1) | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)| -|[2](#bs_rows_2) | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)| -|[3](#bs_rows_3) | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)| -|[4](#bs_rows_4_5_7_8_10) | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)| -|[5](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)| -|[6](#bs_rows_6_9_11) | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)| -|[7](#bs_rows_4_5_7_8_10) | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)| -|[8](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)| -|[9](#bs_rows_6_9_11) | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)| -|[10](#bs_rows_4_5_7_8_10)| mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)| -|[11](#bs_rows_6_9_11) | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)| +|[1](#bs_rows_1) | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)| +|[2](#bs_rows_2) | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)| +|[3](#bs_rows_3) | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)| +|[4](#bs_rows_4_5_7_8_10) | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)| +|[5](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)| +|[6](#bs_rows_6_9_11) | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)| +|[7](#bs_rows_4_5_7_8_10) | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)| +|[8](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)| +|[9](#bs_rows_6_9_11) | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)| +|[10](#bs_rows_4_5_7_8_10)| mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)| +|[11](#bs_rows_6_9_11) | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)| ## MMA tile shapes supported @@ -327,18 +327,18 @@ Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueSche ## Building a Block Scaled Kernel -For non-blockscaled dense GEMM refer to [quick start page](quickstart.md#instantiating-a-blackwell-gemm-kernel). An example dense GEMM can be found: -1. [Blackwell FP16 GEMM example](../../examples/70_blackwell_gemm/). +For non-blockscaled dense GEMM refer to [quick start page](quickstart.md#instantiating-a-blackwell-sm100-gemm-kernel). An example dense GEMM can be found: +1. [Blackwell FP16 GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/70_blackwell_gemm/). Narrow precision and block scaled narrow precision kernels can be built using CUTLASS 3.x collective builder interface (as described in [CUTLASS 3.0 GEMM API](gemm_api_3x.md#cutlass-30-gemm-api)). However, special attention needs to be given to A and B matrix layouts, alignment requirements, and dispatch policies to obtain a functionally correct and performant kernel which are listed above. -Several examples of block scaled kernels can be found in [examples/72_blackwell_narrow_precision_gemm](../../examples/72_blackwell_narrow_precision_gemm/) directory: -1. [NVF4 Gemm with block scaling](../../examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu) -2. [NVF4 Gemm with block scaling and NVF4 output matrix](../../examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu) -3. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](../../examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu) +Several examples of block scaled kernels can be found in [examples/72_blackwell_narrow_precision_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/) directory: +1. [NVF4 Gemm with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu) +2. [NVF4 Gemm with block scaling and NVF4 output matrix](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu) +3. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu) Collective builder interface expects the same arguments as any other CUTLASS 3.x kernels as described [here](gemm_api_3x.md#collective-builder-for-collectivemmas) with a small difference for Collective MMA builder interface. @@ -508,30 +508,24 @@ Typically, GmemLayoutSFD would be same as the GmemLayoutD. ``` Above example made a gentle introduction to using the fusion operations in the epilogue. For more detailed example, see -[Blackwell GEMM with collective builder](../../examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) +[Blackwell GEMM with collective builder](https://github.com/NVIDIA/cutlass/tree/main/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) Note that we have first discussed the CollectiveMainloop, then the CollectiveEpilogue for clarity. -However, the CollectiveMainloop needs to know the SMEM utilization of the epilogue. Therefore, it needs to be setup before the CollectiveMainloop. See [examples/72_blackwell_narrow_precision_gemm](../../examples/72_blackwell_narrow_precision_gemm/) directory for full kernel and run setup. +However, the CollectiveMainloop needs to know the SMEM utilization of the epilogue. Therefore, it needs to be setup before the CollectiveMainloop. See [examples/72_blackwell_narrow_precision_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/) directory for full kernel and run setup. ### Scale Factor Layouts The scale factor layout consists of a 512B basic-block structure, as illustrated in the diagram below. Each block contains 128 M/N dimension and 4 scale factors (SF) along the K dimension. The byte order of the basic storage chunk is row-major, meaning that M0SF0 to M0SF3, M32SF0 to M32SF3, M64SF0 to M64SF3, and M96SF0 to M96SF3 are stored consecutively in GMEM. -[](../images/M128xK4_scalefactor_gmem.png) -

- /M128xK4_scalefactor_gmem.png -

+![ALT](../../images/M128xK4_scalefactor_gmem.png) If the scale factor tensor exceeds M128xSF4, it indicates that there are multiple basic blocks along both the M and SFK dimensions. The arrangement of these basic blocks follows a K-major order. Here is a diagram illustrating the scenario where M equals 512 and the SFK is 16. -[](../images/narrow_precison_multiple_block_sf_layout.png) -

- /narrow_precison_multiple_block_sf_layout.png -

+![ALT](../../images/narrow_precison_multiple_block_sf_layout.png) The creation of scale factor tensors' layouts are tedious. CUTLASS provides `Sm1xxBlockScaledConfig` to create these layouts easily -(See [sm100_blockscaled_layout.hpp](cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp)). +(See [sm100_blockscaled_layout.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/detail/sm100_blockscaled_layout.hpp)). The interface to create SFA and SFB tensor layouts is as follows: ```cpp @@ -548,6 +542,77 @@ auto tensor_sfb = make_tensor(bptr, layout_sfb); // Access SF for for element m,k of A tensor auto val_a_mk = tensor_sfa(make_coord(m,k,0)); ``` +# Blackwell SM120 GEMMs +The NVIDIA RTX 5000 Series GPUs introduce support for new narrow precision (4bit and 6bit) block-scaled and non-block-scaled tensor cores. The PTX ISA has extended the `mma` instructions to support these data formats which are 1x to 4x faster than Ada architecture's fp8 tensor cores. For more detailed information see [`mma` PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma). + +CUTLASS 4.0 has added support for these newly introduced narrow precision GEMMs. Similar to the Blackwell SM100 GEMMs, the SM120 GEMMs can be built using the collective builder interface. See examples in [examples/79_blackwell_geforce_gemm/](../../examples/79_blackwell_geforce_gemm/) and unit tests listed below. + +The data types supported and tensor alignment requirements are the same as the Blackwell SM100 GEMMs. The scale factor layout is also the same as SM100 mentioned above. `OpClassTensorOp` is used for non-blockscaled narrow precision GEMMs and `OpClassBlockScaledTensorOp` is used for blockscaled narrow precision GEMMs. + +| Ptx Instruction | Throughput | Notes | Unit Test | +|---------------------------------------------------------------------|----------------------------|-------|-----------| +|mma.sync.aligned.kind::f8f6f4 | 1x Ada Fp8 Tensor Core(2x for FP32 accumulator) | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN layouts | [unit test](../../test/unit/gemm/device/sm120_tensorop_gemm/) | +|mma.sync.aligned.kind::mxf8f6f4.block_scale | 1x Ada Fp8 Tensor Core(2x for FP32 accumulator) | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN layouts | [unit test](../../test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf6_mxf8_f32_f32.cu) | +|mma.sync.aligned.kind::mxf4.block_scale | 2x Ada Fp8 Tensor Core(4x for FP32 accumulator) | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts | [unit test](../../test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu) | +|mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::[2X\|4X] | 2x Ada Fp8 Tensor Core(4x for FP32 accumulator) | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts | [unit test](../../test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32.cu) | + +Besides the similarities, there are some key differences from the Blackwell SM100 GEMMs: + +## Cluster Size + +On Geforce series graphics card, there is no multicast feature therefore the cluster shape is fixed to 1x1x1. + +## Tensor Layout + +Only TN layout is supported. Matrix A is row major and matrix B is column major. + +## Pingpong v.s. cooperative kernel schedule + +Similar to Hopper's warp-group GEMM, SM120 GEMMs support both pingpong and cooperative kernel schedules. Pingpong kernel schedule has two groups of 4 MMA warps working on different output tiles, overlapping the mainloop and epilogue, while the cooperative kernel schedule has only one group of 8 MMA warps working on the same output tile. If `KernelScheduleAuto` is specified, `KernelTmaWarpSpecializedCooperative` will be selected by default. + +## Epilogue schedule: + +`EpilogueScheduleAuto` must be used. + +## Tile size: + +Below are tables that summarize the valid tile shapes and dispatch policies for SM120 GEMMs. If the output is `float_6_t`, the tile size in the leading dimension of output tensor must be 128. + +**Table 16: Valid Tile Shapes and Dispatch Policies for {float8_t, float_6_t, float_4_t} x {float8_t, float_6_t, float_4_t} of SM120 GEMMs** +| Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|----|----|----|----|------------------------------------| + 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + 64x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + +**Table 17: Valid Tile Shapes for nv_float4_t x nv_float4_t of SM120 GEMMs** +| Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|----|----|----|----|------------------------------------| + 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + 256x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedCooperative` | + 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + +**Table 18: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t of SM120 GEMMs** +| Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|----|----|----|----|------------------------------------| + 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + 256x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedCooperative` | + 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | + +**Table 19: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t of SM120 GEMMs** +| Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|----|----|----|----|------------------------------------| + 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedMxf8f6f4Sm120` or `KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120` | + 256x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedMxf8f6f4Sm120` | + 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecializedMxf8f6f4Sm120` or `KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120` | + +Specialized policies must be used to generate mixed-input-datatype `mx_float4_t` kernels. + +**Table 20: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t, mx_float8_t}** +| Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|----|----|----|----|------------------------------------| + 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` | # Copyright diff --git a/media/docs/build/building_in_windows_with_visual_studio.md b/media/docs/cpp/build/building_in_windows_with_visual_studio.md similarity index 98% rename from media/docs/build/building_in_windows_with_visual_studio.md rename to media/docs/cpp/build/building_in_windows_with_visual_studio.md index 7548e7c7..ebadf321 100644 --- a/media/docs/build/building_in_windows_with_visual_studio.md +++ b/media/docs/cpp/build/building_in_windows_with_visual_studio.md @@ -1,5 +1,3 @@ -[README](../../README.md#documentation) > **CUTLASS 3.0: Building on Windows with Visual Studio** - # Building on Windows with Visual Studio CUTLASS 3.2 reintroduces support for the Microsoft Visual Studio compiler on Windows. diff --git a/media/docs/build/building_with_clang_as_host_compiler.md b/media/docs/cpp/build/building_with_clang_as_host_compiler.md similarity index 97% rename from media/docs/build/building_with_clang_as_host_compiler.md rename to media/docs/cpp/build/building_with_clang_as_host_compiler.md index b1cf6815..47b3971d 100644 --- a/media/docs/build/building_with_clang_as_host_compiler.md +++ b/media/docs/cpp/build/building_with_clang_as_host_compiler.md @@ -1,5 +1,3 @@ -[README](../../README.md#documentation) > **CUTLASS 3: Building with Clang as host compiler** - # Building with Clang as host compiler CUTLASS 3.2(.1) reintroduces support for building with diff --git a/media/docs/code_organization.md b/media/docs/cpp/code_organization.md similarity index 96% rename from media/docs/code_organization.md rename to media/docs/cpp/code_organization.md index fff1ce9c..84d9ab0f 100644 --- a/media/docs/code_organization.md +++ b/media/docs/cpp/code_organization.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Code Organization") - -[README](../../README.md#documentation) > **Code Organization** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Code Organization") # CUTLASS Code Organization @@ -78,13 +76,13 @@ include/ # Top-level include directory. Client applications * # Core library types such as Shape, Stride, Layout, Tensor, and associated operations ``` -See [Programming Guidelines](/media/docs/programming_guidelines.md) for further details about +See [Programming Guidelines](programming_guidelines.md) for further details about conventions and design patterns used throughout CUTLASS. ## CuTe CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly packages the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations. More documentation -for CuTe can be found in [`/media/docs/cute/`](/media/docs/cute/). +for CuTe can be found in [`cute/`](cute/index). ## Tools @@ -138,7 +136,7 @@ and may be built as follows. $ make cutlass_profiler -j ``` -[Further details about the CUTLASS Profiler are described here.](/media/docs/profiler.md) +[Further details about the CUTLASS Profiler are described here.](profiler.md) ### CUTLASS Utilities @@ -166,7 +164,7 @@ tools/ * ``` -[More details about CUTLASS Utilities may be found here.](/media/docs/utilities.md) +[More details about CUTLASS Utilities may be found here.](utilities.md) ## Examples diff --git a/media/docs/cute/00_quickstart.md b/media/docs/cpp/cute/00_quickstart.md similarity index 79% rename from media/docs/cute/00_quickstart.md rename to media/docs/cpp/cute/00_quickstart.md index c0904528..a14437ae 100644 --- a/media/docs/cute/00_quickstart.md +++ b/media/docs/cpp/cute/00_quickstart.md @@ -30,22 +30,22 @@ and how to launch kernels. CuTe's tests and examples build and run as part of CUTLASS's normal build process. -CuTe's unit tests live in the [`test/unit/cute`](../../../test/unit/cute) subdirectory. +CuTe's unit tests live in the [`test/unit/cute`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute) subdirectory. -CuTe's examples live in the [`examples/cute`](../../../examples/cute) subdirectory. +CuTe's examples live in the [`examples/cute`](https://github.com/NVIDIA/cutlass/tree/main/examples/cute) subdirectory. ## Library Organization -CuTe is a header-only C++ library, so there is no source code that needs building. Library headers are contained within the top level [`include/cute`](../../../include/cute) directory, with components of the library grouped by directories that represent their semantics. +CuTe is a header-only C++ library, so there is no source code that needs building. Library headers are contained within the top level [`include/cute`](https://github.com/NVIDIA/cutlass/tree/main/include/cute) directory, with components of the library grouped by directories that represent their semantics. | Directory | Contents | |------------------------|------------------------| -| [`include/cute`](../../../include/cute) | Each header in the top level corresponds to one of the fundamental building blocks of CuTe, such as [`Layout`](../../../include/cute/layout.hpp) and [`Tensor`](../../../include/cute/tensor.hpp). | -| [`include/cute/container`](../../../include/cute/container) | Implementations of STL-like objects, such as tuple, array, and aligned array. | -| [`include/cute/numeric`](../../../include/cute/numeric) | Fundamental numeric data types that include nonstandard floating-point types, nonstandard integer types, complex numbers, and integer sequence. | -| [`include/cute/algorithm`](../../../include/cute/algorithm) | Implementations of utility algorithms such as copy, fill, and clear that automatically leverage architecture-specific features if available. | -| [`include/cute/arch`](../../../include/cute/arch) | Wrappers for architecture-specific matrix-matrix multiply and copy instructions. | -| [`include/cute/atom`](../../../include/cute/atom) | Meta-information for instructions in `arch` and utilities like partitioning and tiling. +| [`include/cute`](https://github.com/NVIDIA/cutlass/tree/main/include/cute) | Each header in the top level corresponds to one of the fundamental building blocks of CuTe, such as [`Layout`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/layout.hpp) and [`Tensor`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/tensor.hpp). | +| [`include/cute/container`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/container) | Implementations of STL-like objects, such as tuple, array, and aligned array. | +| [`include/cute/numeric`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/numeric) | Fundamental numeric data types that include nonstandard floating-point types, nonstandard integer types, complex numbers, and integer sequence. | +| [`include/cute/algorithm`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm) | Implementations of utility algorithms such as copy, fill, and clear that automatically leverage architecture-specific features if available. | +| [`include/cute/arch`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch) | Wrappers for architecture-specific matrix-matrix multiply and copy instructions. | +| [`include/cute/atom`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom) | Meta-information for instructions in `arch` and utilities like partitioning and tiling. ## Tutorial @@ -103,7 +103,7 @@ if (thread0()) { Some algorithms depend on some thread or threadblock, so you may need to print on threads or threadblocks other than zero. The header file -[`cute/util/debug.hpp`](../../../include/cute/util/debug.hpp), +[`cute/util/debug.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/util/debug.hpp), among other utilities, includes the function `bool thread(int tid, int bid)` that returns `true` if running on thread `tid` and threadblock `bid`. diff --git a/media/docs/cute/01_layout.md b/media/docs/cpp/cute/01_layout.md similarity index 97% rename from media/docs/cute/01_layout.md rename to media/docs/cpp/cute/01_layout.md index bf4f4f73..72150634 100644 --- a/media/docs/cute/01_layout.md +++ b/media/docs/cpp/cute/01_layout.md @@ -33,17 +33,17 @@ CuTe provides a number of traits to work with integers. * `cute::is_static`: Checks whether `T` is an empty type (so instantiations cannot depend on any dynamic information). Equivalent to `std::is_empty`. * `cute::is_constant`: Checks that `T` is a static integer AND its value is equivalent to `N`. -See the [`integral_constant` implementations](../../../include/cute/numeric/integral_constant.hpp) for more information. +See the [`integral_constant` implementations](https://github.com/NVIDIA/cutlass/tree/main/include/cute/numeric/integral_constant.hpp) for more information. ### Tuple A tuple is a finite ordered list of zero or more elements. -The [`cute::tuple` class](../../../include/cute/container/tuple.hpp) behaves like `std::tuple`, but works on device and host. It imposes restrictions on its template arguments and strips down the implementation for performance and simplicity. +The [`cute::tuple` class](https://github.com/NVIDIA/cutlass/tree/main/include/cute/container/tuple.hpp) behaves like `std::tuple`, but works on device and host. It imposes restrictions on its template arguments and strips down the implementation for performance and simplicity. ### IntTuple CuTe defines the IntTuple concept as either an integer, or a tuple of IntTuples. Note the recursive definition. -In C++, we define [operations on `IntTuple`](../../../include/cute/int_tuple.hpp). +In C++, we define [operations on `IntTuple`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/int_tuple.hpp). Examples of `IntTuple`s include: * `int{2}`, the dynamic integer 2. @@ -53,7 +53,7 @@ Examples of `IntTuple`s include: CuTe reuses the `IntTuple` concept for many different things, including Shape, Stride, Step, and Coord -(see [`include/cute/layout.hpp`](../../../include/cute/layout.hpp)). +(see [`include/cute/layout.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/layout.hpp)). Operations defined on `IntTuple`s include the following. diff --git a/media/docs/cute/02_layout_algebra.md b/media/docs/cpp/cute/02_layout_algebra.md similarity index 98% rename from media/docs/cute/02_layout_algebra.md rename to media/docs/cpp/cute/02_layout_algebra.md index d8142dbe..f02a98c1 100644 --- a/media/docs/cute/02_layout_algebra.md +++ b/media/docs/cpp/cute/02_layout_algebra.md @@ -17,7 +17,7 @@ In the previous section, we summarized `Layout`s with The `coalesce` operation is a "simplify" on functions from integers to integers. If we only care about input integers, then we can manipulate the shape and number of modes of the `Layout` without changing it as a function. The only thing `coalesce` can't change is the `Layout`'s `size`. -More specifically, you can find the checked post-conditions in [the `coalesce` unit test](../../../test/unit/cute/core/coalesce.cpp), which we'll reproduce here: +More specifically, you can find the checked post-conditions in [the `coalesce` unit test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/coalesce.cpp), which we'll reproduce here: ```cpp // @post size(@a result) == size(@a layout) // @post depth(@a result) <= 1 @@ -116,7 +116,7 @@ compatible(B, R) That is, every coordinate of `B` can also be used as a coordinate of `R`. This is an expected property of functional composition because `B` defines the *domain* of `R`. -You can find many examples and checked post-conditions in [the `composition` unit test](../../../test/unit/cute/core/composition.cpp). The post-conditions are precisely as we just stated. +You can find many examples and checked post-conditions in [the `composition` unit test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/composition.cpp). The post-conditions are precisely as we just stated. ```cpp // @post compatible(@a layout_b, @a result) // @post for all i, 0 <= i < size(@a layout_b), @a result(i) == @a layout_a(@a layout_b(i))) @@ -289,7 +289,7 @@ Before getting to "product" and "divide," we need one more operation. We can thi The `complement` of a layout attempts to find another layout that represents the "rest" -- the elements that aren't touched by the layout. -You can find many examples and checked post-conditions in [the `complement` unit test](../../../test/unit/cute/core/complement.cpp). The post-conditions include +You can find many examples and checked post-conditions in [the `complement` unit test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/complement.cpp). The post-conditions include ```cpp // @post cosize(make_layout(@a layout_a, @a result))) >= size(@a cotarget) // @post cosize(@a result) >= round_up(size(@a cotarget), cosize(@a layout_a)) @@ -309,7 +309,7 @@ The `cotarget` parameter above is most commonly an integer -- you can see we onl ### Complement Examples -`complement` is most effective on static shapes and strides, so consider all integers below to be static. Similar examples for dynamic shapes and strides as well as IntTuple `cotarget` can be found in [the unit test](../../../test/unit/cute/core/complement.cpp). +`complement` is most effective on static shapes and strides, so consider all integers below to be static. Similar examples for dynamic shapes and strides as well as IntTuple `cotarget` can be found in [the unit test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/complement.cpp). * `complement(4:1, 24)` is `6:4`. Note that `(4,6):(1,4)` has cosize `24`. The layout `4:1` is effectively repeated 6 times with `6:4`. diff --git a/media/docs/cute/03_tensor.md b/media/docs/cpp/cute/03_tensor.md similarity index 100% rename from media/docs/cute/03_tensor.md rename to media/docs/cpp/cute/03_tensor.md diff --git a/media/docs/cute/04_algorithms.md b/media/docs/cpp/cute/04_algorithms.md similarity index 91% rename from media/docs/cute/04_algorithms.md rename to media/docs/cpp/cute/04_algorithms.md index a00460ab..6b519729 100644 --- a/media/docs/cute/04_algorithms.md +++ b/media/docs/cpp/cute/04_algorithms.md @@ -4,7 +4,7 @@ This section summarizes the interfaces and implementations of common numerical algorithms performed on `Tensor`s. The implementation of these algorithms may be found in the -[include/cute/algorithm/](../../../include/cute/algorithm/) +[include/cute/algorithm/](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/) directory. ## `copy` @@ -12,7 +12,7 @@ directory. CuTe's `copy` algorithm copies the elements of a source `Tensor` into the elements of a destination `Tensor`. The various overloads of `copy` can be found in -[`include/cute/algorithm/copy.hpp`](../../../include/cute/algorithm/copy.hpp). +[`include/cute/algorithm/copy.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/copy.hpp). ### Interface and specialization opportunities @@ -82,7 +82,7 @@ such as `cp.async`, or its C++ interface `memcpy_async`. In that case, users will need to perform the additional synchronization appropriate to that underlying implementation before they may use the results of the `copy` algorithm. -[The CuTe GEMM tutorial example](../../../examples/cute/tutorial/) +[The CuTe GEMM tutorial example](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/) shows one such synchronization method. More optimized GEMM implementations use pipelining techniques to overlap asynchronous `copy` operations with other useful work. @@ -129,7 +129,7 @@ CuTe's optimized copy implementations can do all of these. ## `copy_if` CuTe's `copy_if` algorithm lives in the same header as `copy`, -[`include/cute/algorithm/copy.hpp`](../../../include/cute/algorithm/copy.hpp). +[`include/cute/algorithm/copy.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/copy.hpp). The algorithm takes source and destination `Tensor` parameters like `copy`, but it also takes a "predication `Tensor`" with the same shape as the input and output. @@ -195,7 +195,7 @@ for different architectures, please refer to the ## `axpby` The `axpby` algorithm lives in the header file -[`include/cute/algorithm/axpby.hpp`](../../../include/cute/algorithm/axpby.hpp). +[`include/cute/algorithm/axpby.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/axpby.hpp). It assigns to $y$ the result of $\alpha x + \beta y$, where $\alpha$ and $\beta$ are scalars and $x$ and $y$ are `Tensor`s. The name stands for "Alpha times X Plus Beta times Y," @@ -205,21 +205,21 @@ and is a generalization of the original BLAS "AXPY" routine ## `fill` The `fill` algorithm lives in the header file -[`include/cute/algorithm/fill.hpp`](../../../include/cute/algorithm/fill.hpp). +[`include/cute/algorithm/fill.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/fill.hpp). It overwrites the elements of its `Tensor` output argument with a given scalar value. ## `clear` The `clear` algorithm lives in the header file -[`include/cute/algorithm/clear.hpp`](../../../include/cute/algorithm/clear.hpp). +[`include/cute/algorithm/clear.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/clear.hpp). It overwrites the elements of its `Tensor` output argument with zeros. ## Other algorithms CuTe provides other algorithms. Their header files can be found in the -[`include/cute/algorithm`](../../../include/cute/algorithm) +[`include/cute/algorithm`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm) directory. ## Copyright diff --git a/media/docs/cute/0t_mma_atom.md b/media/docs/cpp/cute/0t_mma_atom.md similarity index 96% rename from media/docs/cute/0t_mma_atom.md rename to media/docs/cpp/cute/0t_mma_atom.md index 8896d9b9..d5d8bea8 100644 --- a/media/docs/cute/0t_mma_atom.md +++ b/media/docs/cpp/cute/0t_mma_atom.md @@ -66,7 +66,7 @@ including #### Location of files CuTe provides its Operations structs in the -[`include/cute/arch`](../../../include/cute/arch) +[`include/cute/arch`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch) directory, in header files starting with `mma`. #### Operation struct's name @@ -84,7 +84,7 @@ These often include For example, the Volta section below will refer to the `SM70_8x8x4_F32F16F16F32_NT` Operation struct defined in -[`include/cute/arch/mma_sm70.hpp`](../../../include/cute/arch/mma_sm70.hpp). +[`include/cute/arch/mma_sm70.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/mma_sm70.hpp). * "SM70" refers to Volta. @@ -111,7 +111,7 @@ An Operation struct has the following members. An Operation struct has four public type aliases: `DRegisters`, `ARegisters`, `BRegisters`, and `CRegisters`. For example, the `SM70_8x8x4_F32F16F16F32_NT` Operation struct defined in -[`include/cute/arch/mma_sm70.hpp`](../../../include/cute/arch/mma_sm70.hpp) +[`include/cute/arch/mma_sm70.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/mma_sm70.hpp) defines these as follows. ```c++ @@ -145,7 +145,7 @@ can still compile, even if the PTX instruction is not available. #### Location of files CuTe provides its Traits structs in the -[`include/cute/atom`](../../../include/cute/atom) +[`include/cute/atom`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom) directory, in header files starting with `mma_traits`. #### Contents @@ -175,7 +175,7 @@ An `MMA_Traits` specialization defines the following public type aliases. The specialization of MMA_Traits for the `SM70_8x8x4_F32F16F16F32_NT` Operation lives in the header file -[`include/cute/atom/mma_traits_sm70.hpp`](../../../include/cute/atom/mma_traits_sm70.hpp). +[`include/cute/atom/mma_traits_sm70.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/mma_traits_sm70.hpp). It looks like this. ```c++ @@ -254,7 +254,7 @@ Let us look at exactly how the 8 threads within a QP are mapped to the A, B and HMMA.8x8x4.quadpair.C.png

-The metainformation of this single instruction level view is what we want to encode in CuTe. Specifically, the QP level view in this diagram corresponds to the four MMA traits for [SM70_F32F16F16F32](../../../include/cute/arch/mma_sm70.hpp). These structs contain the `Element` types, the `Shape_MNK`, and the `ThrID` mapping we constructed above. Now, let us take a look at the definition of `CLayout`, the thread-data layout of accumulators. The job of `CLayout` is to construct a mapping between the `(logical_thr_id, logical_val_id)` and `(m, n)` coordinate in the C matrix which can then be used to build up more complicated layouts and operations like the 16x16x4 WMMA. +The metainformation of this single instruction level view is what we want to encode in CuTe. Specifically, the QP level view in this diagram corresponds to the four MMA traits for [SM70_F32F16F16F32](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/mma_sm70.hpp). These structs contain the `Element` types, the `Shape_MNK`, and the `ThrID` mapping we constructed above. Now, let us take a look at the definition of `CLayout`, the thread-data layout of accumulators. The job of `CLayout` is to construct a mapping between the `(logical_thr_id, logical_val_id)` and `(m, n)` coordinate in the C matrix which can then be used to build up more complicated layouts and operations like the 16x16x4 WMMA. We can start constructing a `CLayout` from the picture above. As with any CuTe layout, it is a pair of `Shape` and corresponding `Stride`. Let us just look at the shape for now. We know that the HMMA uses 8 threads each of which own 8 values. Therefore, the shape of our mapping must have a size of 8 along two modes. With this, we have diff --git a/media/docs/cute/0x_gemm_tutorial.md b/media/docs/cpp/cute/0x_gemm_tutorial.md similarity index 99% rename from media/docs/cute/0x_gemm_tutorial.md rename to media/docs/cpp/cute/0x_gemm_tutorial.md index beb51523..630df5c5 100644 --- a/media/docs/cute/0x_gemm_tutorial.md +++ b/media/docs/cpp/cute/0x_gemm_tutorial.md @@ -1,7 +1,7 @@ # CuTe dense matrix-matrix multiply tutorial In this section, we review -[these examples](../../../examples/cute/tutorial/), +[these examples](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/), which demonstrate a few self-contained, single-file dense matrix-matrix multiply implementations using only CuTe. ## `sgemm_1.cu` @@ -535,7 +535,7 @@ gett(int m0, int m1, int n, int k, ``` Note that the only changes are the definition of shape `M`, the definition of strides `dA` and `dC`, and the definition of the CTA Tiler `bM`. The above uses a multimodel problem shape `M = (m0,m1)` and a multimodal CTA Tiler `bM = <_64,_2>` to change which portion of the global memory tensors `A` and `C` each CTA will be responsible for computing. -Similar examples can be found for CUTLASS 3.x kernels that are based on CuTe, such as [this Hopper GETT example](../../../examples/51_hopper_gett). +Similar examples can be found for CUTLASS 3.x kernels that are based on CuTe, such as [this Hopper GETT example](https://github.com/NVIDIA/cutlass/tree/main/examples/51_hopper_gett). ## Copyright diff --git a/media/docs/cute/0y_predication.md b/media/docs/cpp/cute/0y_predication.md similarity index 100% rename from media/docs/cute/0y_predication.md rename to media/docs/cpp/cute/0y_predication.md diff --git a/media/docs/cute/0z_tma_tensors.md b/media/docs/cpp/cute/0z_tma_tensors.md similarity index 100% rename from media/docs/cute/0z_tma_tensors.md rename to media/docs/cpp/cute/0z_tma_tensors.md diff --git a/media/docs/cpp/cute/index.rst b/media/docs/cpp/cute/index.rst new file mode 100644 index 00000000..a6611dd7 --- /dev/null +++ b/media/docs/cpp/cute/index.rst @@ -0,0 +1,17 @@ +.. _cpp_cute: + +CuTe +==================== + +.. toctree:: + :maxdepth: 2 + + 00_quickstart<00_quickstart.md> + 01_layout<01_layout.md> + 02_layout_algebra<02_layout_algebra.md> + 03_tensor<03_tensor.md> + 04_algorithms<04_algorithms.md> + 0t_mma_atom<0t_mma_atom.md> + 0x_gemm_tutorial<0x_gemm_tutorial.md> + 0y_predication<0y_predication.md> + 0z_tma_tensors<0z_tma_tensors.md> diff --git a/media/docs/cutlass_3x_backwards_compatibility.md b/media/docs/cpp/cutlass_3x_backwards_compatibility.md similarity index 94% rename from media/docs/cutlass_3x_backwards_compatibility.md rename to media/docs/cpp/cutlass_3x_backwards_compatibility.md index 85eca7d6..1dc42ef7 100644 --- a/media/docs/cutlass_3x_backwards_compatibility.md +++ b/media/docs/cpp/cutlass_3x_backwards_compatibility.md @@ -1,5 +1,3 @@ -[README](../../README.md#documentation) > **CUTLASS 3.0 GEMM Backwards Compatibility** - # CUTLASS 3.0 GEMM Backwards Compatibility Although CUTLASS 3.0 restructures the GEMM hierarchy and introduces new types for the @@ -16,7 +14,7 @@ The entry point for CUTLASS's Device GEMM API is the class `cutlass::gemm::device::GemmUniversalAdapter`. This class lives in the header file -[include/cutlass/gemm/device/gemm_universal_adapter.h](/include/cutlass/gemm/device/gemm_universal_adapter.h). +[include/cutlass/gemm/device/gemm_universal_adapter.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_universal_adapter.h). `GemmUniversalAdapter` is a "universal adapter" and serves as a common device interface @@ -89,7 +87,7 @@ and a collective epilogue. The entry point for CUTLASS's kernel API is the class `cutlass::gemm::kernel::GemmUniversal`. This class' declaration lives in the header file -[include/cutlass/gemm/kernel/gemm_universal.hpp](/include/cutlass/gemm/kernel/gemm_universal.hpp). +[include/cutlass/gemm/kernel/gemm_universal.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemm_universal.hpp). ```c++ /* @@ -128,11 +126,11 @@ Each kernel layer schedule is specialized for a GEMM scheduling algorithm and GPU architecture. Specializations of `kernel::GemmUniversal` for 3.0 APIs live in any of various `gemm_*.hpp` files in the directory -[include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/). +[include/cutlass/gemm/kernel/](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/). The specialization to which to dispatch is decided through the dispatch policy's `Schedule` type. Specializations for 2.x APIs live in the header file -[include/cutlass/gemm/kernel/gemm_universal.h](../../include/cutlass/gemm/kernel/gemm_universal.h). +[include/cutlass/gemm/kernel/gemm_universal.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemm_universal.h). ### Kernel API design differences @@ -204,7 +202,7 @@ if they wish to author custom mainloop code in the 3.x API. Similarly, for the GEMM inner loops, `cute::MMA_Atom`s replace the `gemm::warp` and `gemm::thread` layer code. Going forward, all new PTX instructions -and associated metadata development will occur directly inside [`cute/arch/*.hpp`](/include/cute/arch/) and [`cute/atom/*.hpp`](/include/cute/atom/). +and associated metadata development will occur directly inside [`cute/arch/*.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/) and [`cute/atom/*.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/). The desired inner loop MMA iteration order and tiling can be achieved through careful selection of the atom layout, value layout, and permutations of the `cute::TiledMma`. @@ -212,7 +210,7 @@ selection of the atom layout, value layout, and permutations of the `cute::Tiled For epilogues, the `cutlass::epilogue::collective` layer replaces `cutlass::threadblock::collective`. However, the thread-level epilogue elementwise operations in `cutlass::epilogue::thread` will continue to be used in 3.x kernels as well, albeit, with a more idiomatic epilogue vectorization strategy. -[Example 50](/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu) +[Example 50](https://github.com/NVIDIA/cutlass/tree/main/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu) shows how to use 2.x epilogue thread operators with 3.0 API kernels. ## Porting from 2.x to 3.0 API @@ -271,7 +269,7 @@ For the matrix B, CUTLASS 2.x defines "layout tag" classes `cutlass::layout::ColumnMajor` and `cutlass::layout::RowMajor`, that live in the header file -[`cutlass/layout/matrix.h`](/include/cutlass/layout/matrix.h). +[`cutlass/layout/matrix.h`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/layout/matrix.h). The interpretation of these layouts in GEMM depends on whether they are applied to the input matrix A or B. For the matrix A, "column major" means @@ -304,7 +302,7 @@ whether we are talking about the A or B matrix. M and N major inputs always have static size-1 stride in their 0th (outer) mode. Similarly, K major inputs always contain the static size-1 stride in their 1st mode. This uniformity in stride order allows us to represent tensor layouts much more cleanly and treat both A and B equally in our interfaces. -See for example the following snippet from our [`kernel/sm70_gemm.hpp`](/include/cutlass/gemm/kernel/sm70_gemm.hpp) +See for example the following snippet from our [`kernel/sm70_gemm.hpp`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm70_gemm.hpp) for Ampere kernel schedules. ```c++ @@ -352,7 +350,7 @@ dynamic stride modes corresponding to the minor mode and the batch mode. Batch mode is included by default as all CUTLASS 3.0 kernels support packed batch-mode GEMMs out of the box. -The [`cutlass/gemm/gemm.h#440`](../../include/cutlass/gemm/gemm.h#440) +The [`cutlass/gemm/gemm.h#440`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/gemm.h#440) header file includes functions that can be useful for converting from CUTLASS 3.0 `cute::Stride`s back to CUTLASS 2.x layout tags. @@ -375,7 +373,7 @@ these 2.x reflective types from an assembled kernel with a more stable API, the specialization of `cutlass::gemm::device::GemmUniversalAdapter` for CUTLASS 3.0 kernel provides all aliases for all 2.x type aliases in addition to the layout tags. You can see how they are used in the header file -[`cutlass/gemm/device/gemm_universal_adapter.h`](/include/cutlass/gemm/device/gemm_universal_adapter.h). +[`cutlass/gemm/device/gemm_universal_adapter.h`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_universal_adapter.h). Here is an excerpt. ```c++ diff --git a/media/docs/cutlass_3x_design.md b/media/docs/cpp/cutlass_3x_design.md similarity index 95% rename from media/docs/cutlass_3x_design.md rename to media/docs/cpp/cutlass_3x_design.md index 54d6c35c..b1eed530 100644 --- a/media/docs/cutlass_3x_design.md +++ b/media/docs/cpp/cutlass_3x_design.md @@ -1,5 +1,3 @@ -[README](../../README.md#documentation) > **CUTLASS 3.0 Design and Hierarchy** - # CUTLASS 3.0 Design CUTLASS 3.0 is a major enhancement over the abstractions of CUTLASS 2.x @@ -29,7 +27,7 @@ CUTLASS 3.0 has the following design goals, in no particular order. CUTLASS 2.x decomposes the moving parts of a GEMM operation across a hierarchy that closely mirrors the organization of GPU architectures. This discussed in detail within the -[CUTLASS 2.x GEMM API documentation](/media/docs/gemm_api.md). +[CUTLASS 2.x GEMM API documentation](gemm_api.md). This design, however, sometimes results in a coupling that is too tight to extend to newer GPU features that might not fit into the same architectural hierarchy. For instance, Hopper's warp-group wide instructions do not naturally @@ -46,7 +44,7 @@ with a consistent interface to hardware acceleration regardless of the architecture specific details. The new conceptual GEMM hierarchy is discussed in detail in the dedicated -[CUTLASS 3.0 GEMM API documentation readme](/media/docs/gemm_api_3x.md), +[CUTLASS 3.0 GEMM API documentation readme](gemm_api_3x.md), along with code examples of the core concepts and types. ## Adoption of CuTe Layout and Tensors @@ -55,9 +53,9 @@ CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tens CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly packages the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates, greatly simplifying the design, -improving code composability, and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md). +improving code composability, and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](cute/00_quickstart.md). -![CuTe helps reduce named iterator types down to a single vocabulary type, `Layout`](/media/images/cutlass-reduction-in-named-iterators.png) +![CuTe helps reduce named iterator types down to a single vocabulary type, `Layout`](../../images/cutlass-reduction-in-named-iterators.png) Programming massively parallel systems with various layers of logical thread and data hierarchies is not a trivial task. diff --git a/media/docs/dependent_kernel_launch.md b/media/docs/cpp/dependent_kernel_launch.md similarity index 93% rename from media/docs/dependent_kernel_launch.md rename to media/docs/cpp/dependent_kernel_launch.md index a5d0a514..1beb8bf7 100644 --- a/media/docs/dependent_kernel_launch.md +++ b/media/docs/cpp/dependent_kernel_launch.md @@ -1,5 +1,3 @@ -[README](../../README.md#documentation) > **Dependent kernel launch** - # Dependent kernel launches The Hopper and Blackwell architectures supports a new feature through which two kernels in the same stream can @@ -37,11 +35,11 @@ gemm.run( ``` ## Model-Aware Optimizations with PDL -In [example 63](../../examples/63_hopper_gemm_with_weight_prefetch/README.md), we use PDL to explicitly optimize for +In [example 63](https://github.com/NVIDIA/cutlass/tree/main/examples/63_hopper_gemm_with_weight_prefetch/README.md), we use PDL to explicitly optimize for performance of kernels where we know that one of the input matricies (our weights) will not be produced by a prior kernel. In that case, we only need to wait on the prior kernels memory flush in order to load the other input matrix (our activations). During our prologue, we can prefetch our weights to improve performance for memory bandwidth-bound -problem sizes. For more informations we refer the reader to [the example](../../examples/63_hopper_gemm_with_weight_prefetch/README.md). +problem sizes. For more informations we refer the reader to [the example](https://github.com/NVIDIA/cutlass/tree/main/examples/63_hopper_gemm_with_weight_prefetch/README.md). ## Copyright diff --git a/media/docs/doxygen_mainpage.md b/media/docs/cpp/doxygen_mainpage.md similarity index 97% rename from media/docs/doxygen_mainpage.md rename to media/docs/cpp/doxygen_mainpage.md index 1ff521ac..c9f9dc9a 100644 --- a/media/docs/doxygen_mainpage.md +++ b/media/docs/cpp/doxygen_mainpage.md @@ -33,7 +33,7 @@ to CUTLASS 3.0, please refer to the For a code example showing how to define a GEMM kernel using CUTLASS, please refer to [the quickstart guide](./quickstart.md). -The [`examples` directory](../../examples) +The [`examples` directory](https://github.com/NVIDIA/cutlass/tree/main/examples) has a variety of examples. # Copyright diff --git a/media/docs/efficient_gemm.md b/media/docs/cpp/efficient_gemm.md similarity index 82% rename from media/docs/efficient_gemm.md rename to media/docs/cpp/efficient_gemm.md index 470c4eee..771d24db 100644 --- a/media/docs/efficient_gemm.md +++ b/media/docs/cpp/efficient_gemm.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "Efficient GEMM in CUDA") - -[README](../../README.md#documentation) > **Efficient GEMM in CUDA** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "Efficient GEMM in CUDA") # Efficient GEMM in CUDA @@ -60,7 +58,7 @@ This is the hierarchical GEMM computation embodied by CUTLASS. Each stage depict nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming increasingly finer moving left to right. -![ALT](../images/gemm-hierarchy-with-epilogue.png "Hierarchical GEMM in CUDA") +![ALT](../../images/gemm-hierarchy-with-epilogue.png "Hierarchical GEMM in CUDA") ### Threadblock-level GEMM @@ -154,7 +152,7 @@ following scopes. The following diagram illustrates the efficient, pipelined mainloop body used in CUTLASS GEMMs. -![ALT](../images/software-pipeline.png "Software pipeline in CUTLASS") +![ALT](../../images/software-pipeline.png "Software pipeline in CUTLASS") ### Threadblock Rasterization @@ -164,7 +162,7 @@ consecutively launched threadblocks to packed two-dimensional regions of the par problem to increase the probability that these will access the same tiles of global memory at approximately the same time. -Several functions are defined in [cutlass/gemm/threadblock_swizzle.h](../../include/cutlass/gemm/threadblock/threadblock_swizzle.h). +Several functions are defined in [cutlass/gemm/threadblock_swizzle.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/threadblock/threadblock_swizzle.h). ### Parallelized Reductions @@ -226,26 +224,26 @@ to the Hopper kernel design. Blackwell SM100 kernels have a substantially differ however, the concept of separating out producer and consumer agents still applies. Starting with Hopper, CUTLASS 3.0 incorporates the concept of [Warp Specialization](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#spatial-partitioning-also-known-as-warp-specialization) -as part of the kernel design. A thread block is partitioned into two sets of warps, [*producer* warp group](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [*consumer* warp group](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp). The *producer* warp group loads data from global memory into shared memory buffers using the new [Tensor Memory Accelerator (TMA)](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/). +as part of the kernel design. A thread block is partitioned into two sets of warps, [*producer* warp group](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [*consumer* warp group](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp). The *producer* warp group loads data from global memory into shared memory buffers using the new [Tensor Memory Accelerator (TMA)](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/). -[*Producer* warp group (DMA)](../../include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) waits for the shared memory buffers to be signaled as [empty](../../include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) by the *consumer* warp group using the newly added **Async Pipeline class** ([refer](pipeline.md)). Once the data is written into the shared memory, TMA is also updates the barrier associated with that stage to notify affected threads that the buffer has been [filled](../../include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp). The [*Consumer* warp group (MMA)](../../include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) on the other hand waits for the *producer* warp group to signal that the buffer is [filled](../../include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) and then launches tensor core MMA operations. Finally, the *consumer* warp group [releases](../../include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) the buffers for the next set of TMA loads to happens. +[*Producer* warp group (DMA)](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) waits for the shared memory buffers to be signaled as [empty](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) by the *consumer* warp group using the newly added **Async Pipeline class** ([refer](pipeline.md)). Once the data is written into the shared memory, TMA is also updates the barrier associated with that stage to notify affected threads that the buffer has been [filled](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp). The [*Consumer* warp group (MMA)](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) on the other hand waits for the *producer* warp group to signal that the buffer is [filled](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) and then launches tensor core MMA operations. Finally, the *consumer* warp group [releases](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) the buffers for the next set of TMA loads to happens. **Warp-Specialized Persistent Cooperative kernel design** -Another flavor of Warp-Specialized kernel design being introduced starting with Hopper is the [*Warp-Specialized Persistent Cooperative*](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel. Like the Warp-Specialized kernel, the concepts of warp groups and barrier synchronization between warp groups remain the same in the cooperative design. +Another flavor of Warp-Specialized kernel design being introduced starting with Hopper is the [*Warp-Specialized Persistent Cooperative*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel. Like the Warp-Specialized kernel, the concepts of warp groups and barrier synchronization between warp groups remain the same in the cooperative design. The distinctive feature of the Warp-Specialized Persistent Cooperative kernel are the following : -* Persistent thread blocks launched to occupy as many SMs as mentioned in the [KernelHardwareInfo](../../include/cutlass/kernel_hardware_info.hpp) struct. These persistent thread blocks are used to tile the output and thus (potentially) compute multiple output tiles through their lifetime. The main benefit this adds is amortization of the thread-block launch and kernel prologue overheads which are typical of all kernels. +* Persistent thread blocks launched to occupy as many SMs as mentioned in the [KernelHardwareInfo](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/kernel_hardware_info.hpp) struct. These persistent thread blocks are used to tile the output and thus (potentially) compute multiple output tiles through their lifetime. The main benefit this adds is amortization of the thread-block launch and kernel prologue overheads which are typical of all kernels. * Presence of two *consumer* warp groups cooperating on the same output tile by splitting the tile in half across the M dimension. This allows for larger tile sizes to be enabled - since the register pressure per *consumer* warp group is reduced - and hence improving performance. -Since each thread block now computes multiple output tiles, the shape of the grid launch and the scheduling of tiles to the thread blocks is managed using the new [*Tile Scheduler*](../../include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). The *Tile Scheduler* considers the shape of the *clusters* as well as the available number of available SMs to compute a valid scheduling of the output tiles to launched thread blocks. +Since each thread block now computes multiple output tiles, the shape of the grid launch and the scheduling of tiles to the thread blocks is managed using the new [*Tile Scheduler*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). The *Tile Scheduler* considers the shape of the *clusters* as well as the available number of available SMs to compute a valid scheduling of the output tiles to launched thread blocks. **Warp-Specialized Persistent Ping-Pong kernel design** -The third kernel design is the [*Warp-Specialized Persistent Ping-Pong*](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel. +The third kernel design is the [*Warp-Specialized Persistent Ping-Pong*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel. Like the Warp-Specialized Persistent Cooperative, kernel the concepts of warp groups, barrier synchronization between warp groups, and the shape of the grid launch remain the same in the persistent ping-pong design. The distinctive feature of the Warp-Specialized Persistent Ping-Pong kernel is the following : * The two *consumer* warp groups are assigned a different output tile using the Tile Scheduler. This allows for *epilogue* of one *consumer* warp group to be overlapped with the math operations of the other *consumer* warp group - thus maximizing tensor core utilization. -* The *producer* warp group synchronizes using the [Ordered Sequence Barrier](../../include/cutlass/pipeline/pipeline.hpp) to fill buffers of the two *consumer* warp groups one after the other in order. +* The *producer* warp group synchronizes using the [Ordered Sequence Barrier](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/pipeline.hpp) to fill buffers of the two *consumer* warp groups one after the other in order. # Resources diff --git a/media/docs/functionality.md b/media/docs/cpp/functionality.md similarity index 71% rename from media/docs/functionality.md rename to media/docs/cpp/functionality.md index 274bba62..396db1fe 100644 --- a/media/docs/functionality.md +++ b/media/docs/cpp/functionality.md @@ -1,17 +1,15 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Functionality") +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Functionality") -[README](../../README.md#documentation) > **Functionality** # Functionality Note : CUTLASS-3 requires users to use CUDA 11.4 or newer, and SM70 or newer, for the target toolkit and architecture, respectively. -Please refer to the [Compatibility](/README.md#Compatibility) section for more details. - N - Column Major Matrix - T - Row Major matrix - {N,T} x {N,T} - All combinations, i.e., NN, NT, TN, TT -- [NHWC](/include/cutlass/layout/tensor.h#L63-206) - 4 dimension tensor used for convolution -- [NCxHWx](/include/cutlass/layout/tensor.h#L290-395) - Interleaved 4 dimension tensor used for convolution +- [NHWC](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/layout/tensor.h#L63-206) - 4 dimension tensor used for convolution +- [NCxHWx](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/layout/tensor.h#L290-395) - Interleaved 4 dimension tensor used for convolution - f - floating point - s - signed int - b - bit @@ -32,48 +30,48 @@ Hyperlinks to relevant unit tests demonstrate how specific template instances ma |**Opcode Class** | **Compute Capability** | **CUDA Toolkit** | **Data Type** | **Layouts** | **Unit Test** | |-----------------|------------------------|------------------|--------------------------------|------------------------|------------------| -| **TensorOp** | 90a | 12.0+ | `f16 * f16 + { f16, f32 } => { f16, f32 }` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu) | -| **TensorOp** | 90a | 12.0+ | `bf16 * bf16 + { f16, f32 } => { bf16, f32 }`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu) | -| **TensorOp** | 90a | 12.0+ | `{f32, tf32} * {f32, tf32} + f32 => f32`| { T } x { N } => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu) | -| **TensorOp** | 90a | 12.0+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu) | +| **TensorOp** | 90a | 12.0+ | `f16 * f16 + { f16, f32 } => { f16, f32 }` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu) | +| **TensorOp** | 90a | 12.0+ | `bf16 * bf16 + { f16, f32 } => { bf16, f32 }`| {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu) | +| **TensorOp** | 90a | 12.0+ | `{f32, tf32} * {f32, tf32} + f32 => f32`| { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu) | +| **TensorOp** | 90a | 12.0+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu) | ### CUTLASS 2.x Kernels |**Opcode Class** | **Compute Capability** | **CUDA Toolkit** | **Data Type** | **Layouts** | **Unit Test** | |-----------------|------------------------|------------------|--------------------------------|------------------------|------------------| -| **Simt** | 50+ | 11.4+ | `f32 * f32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_sgemm_nt_sm50.cu) | -| **Simt** | 50+ | 11.4+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_dgemm_nt_sm50.cu) | -| **Simt** | 60+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_hgemm_nt_sm50.cu) | -| **Simt** | 61+ | 11.4+ | `s8 * s8 + s32 => {s32,s8}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/simt_igemm_nt_sm50.cu) | -| **WmmaTensorOp** | 70+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu) | -| **WmmaTensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu) | -| **WmmaTensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu) | -| **WmmaTensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s4t_wmma_tensor_op_s32_sm75.cu) | -| **WmmaTensorOp** | 75+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_b1t_wmma_tensor_op_s32_sm75.cu) | -| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu) | -| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f32_sm70.cu) | -| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu) | -| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm75.cu) | -| **TensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu) | -| **TensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu) | -| **TensorOp** | 75+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu) | -| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `bf16 * bf16 + f32 => {bf16, f32}`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_bf16n_bf16t_bf16t_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32`| {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `cf64 * cf64 + cf64 => cf64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu), [Gaussian 3m](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu) | -| **SpTensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | -| **SpTensorOp** | 80+ | 11.4+ | `bf16 * bf16 + f32 => {bf16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | -| **SpTensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu) | -| **SpTensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s8, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu) | -| **SpTensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s4, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu) | -| **TensorOp** | 90+ | 11.8+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) | +| **Simt** | 50+ | 11.4+ | `f32 * f32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/simt_sgemm_nt_sm50.cu) | +| **Simt** | 50+ | 11.4+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/simt_dgemm_nt_sm50.cu) | +| **Simt** | 60+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/simt_hgemm_nt_sm50.cu) | +| **Simt** | 61+ | 11.4+ | `s8 * s8 + s32 => {s32,s8}` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/simt_igemm_nt_sm50.cu) | +| **WmmaTensorOp** | 70+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu) | +| **WmmaTensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu) | +| **WmmaTensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu) | +| **WmmaTensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s4t_wmma_tensor_op_s32_sm75.cu) | +| **WmmaTensorOp** | 75+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_b1t_b1n_b1t_wmma_tensor_op_s32_sm75.cu) | +| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu) | +| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f32_sm70.cu) | +| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f16 => f16` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `bf16 * bf16 + f32 => {bf16, f32}`| {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_bf16n_bf16t_bf16t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32`| {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `cf64 * cf64 + cf64 => cf64` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu), [Gaussian 3m](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `bf16 * bf16 + f32 => {bf16, f32}` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s8, s32}` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu) | +| **SpTensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s4, s32}` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu) | +| **TensorOp** | 90+ | 11.8+ | `f64 * f64 + f64 => f64` | {N,T} x {N,T} => {N,T} | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) | ## Device-level Implicit GEMM convolution @@ -84,19 +82,19 @@ One can find and/or create equivalent dgrad and wgrad convolutional operators. |**Opcode Class** | **Compute Capability** | **CUDA Toolkit** | **Data Type** | **Layouts** | **Unit Test** | |-----------------|------------------------|------------------|--------------------------------|------------------|------------------| -| **Simt** | 50+ | 11.4+ | `f32 * f32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu) | -| **Simt** | 50+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu) | -| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu) | -| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu) | -| **TensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu) | -| **TensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu) | -| **Simt** | 80+ | 11.4+ | `f32 * f32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu) | -| **Simt** | 80+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f16 => f16` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32` | NHWC | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu) | -| **TensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu) | +| **Simt** | 50+ | 11.4+ | `f32 * f32 + f32 => f32` | NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu) | +| **Simt** | 50+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu) | +| **TensorOp** | 70+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu) | +| **TensorOp** | 75+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu) | +| **TensorOp** | 75+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu), [ncxhwx](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu) | +| **Simt** | 80+ | 11.4+ | `f32 * f32 + f32 => f32` | NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu) | +| **Simt** | 80+ | 11.4+ | `cf32 * cf32 + cf32 => cf32` | NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f32 => {f16, f32}`| NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `f16 * f16 + f16 => f16` | NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `tf32 * tf32 + f32 => f32` | NHWC | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s8 * s8 + s32 => {s32, s8}` | NHWC, NCxHWx | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu) | +| **TensorOp** | 80+ | 11.4+ | `s4 * s4 + s32 => {s32, s4}` | NHWC, NCxHWx | [example](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu), [ncxhwx](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu) | diff --git a/media/docs/fundamental_types.md b/media/docs/cpp/fundamental_types.md similarity index 99% rename from media/docs/fundamental_types.md rename to media/docs/cpp/fundamental_types.md index 3bfc4453..b29fb5bf 100644 --- a/media/docs/fundamental_types.md +++ b/media/docs/cpp/fundamental_types.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS") - -[README](../../README.md#documentation) > **Fundamental Types** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS") # Fundamental Types diff --git a/media/docs/gemm_api.md b/media/docs/cpp/gemm_api.md similarity index 90% rename from media/docs/gemm_api.md rename to media/docs/cpp/gemm_api.md index e2aaaccb..fd8ecf5e 100644 --- a/media/docs/gemm_api.md +++ b/media/docs/cpp/gemm_api.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS GEMM API") - -[README](../../README.md#documentation) > **CUTLASS GEMM API** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS GEMM API") # CUTLASS GEMM API @@ -69,7 +67,7 @@ thread-level concurrency. This loop nest is expressed in CUTLASS via the following components which are specialized for data type, layout, and math instruction. -![ALT](/media/images/cutlass-gemm-components.png "CUTLASS GEMM Components") +![ALT](../../images/cutlass-gemm-components.png "CUTLASS GEMM Components") These components are described in the following sections. @@ -80,10 +78,10 @@ GEMM computation across the GPU. This operator is intended to be used in host-si has semantics similar to cuBLAS. The device-wide GEMM API is embodied by the following operators: -- [cutlass::gemm::device::Gemm](/include/cutlass/gemm/device/gemm.h) - basic GEMM operation -- [cutlass::gemm::device::GemmArray](/include/cutlass/gemm/device/gemm_array.h) - batched GEMM operation in which input matrices are read from arrays of pointers -- [cutlass::gemm::device::GemmBatched](/include/cutlass/gemm/device/gemm_batched.h) - batched GEMM operation in which input matrices are separated by a constant stride -- [cutlass::gemm::device::GemmSplitKParallel](/include/cutlass/gemm/device/gemm_splitk_parallel.h) - GEMM operation that partitions the GEMM K dimension then launches a separate reduction kernel +- [cutlass::gemm::device::Gemm](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm.h) - basic GEMM operation +- [cutlass::gemm::device::GemmArray](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_array.h) - batched GEMM operation in which input matrices are read from arrays of pointers +- [cutlass::gemm::device::GemmBatched](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_batched.h) - batched GEMM operation in which input matrices are separated by a constant stride +- [cutlass::gemm::device::GemmSplitKParallel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_splitk_parallel.h) - GEMM operation that partitions the GEMM K dimension then launches a separate reduction kernel **Example:** launch a mixed-precision GEMM targeting Volta Tensor Cores. ```c++ @@ -127,14 +125,14 @@ GEMMs at this scope are expected to efficiently load tiles of data from global m products with warp-level GEMM operators. The threadblock-scoped matrix multiply operation is embodied by -[cutlass::gemm::threadblock::MmaPipelined](/include/cutlass/gemm/threadblock/mma_pipelined.h). +[cutlass::gemm::threadblock::MmaPipelined](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/threadblock/mma_pipelined.h). This is a class inspired by [std::transform_reduce()](https://en.cppreference.com/w/cpp/algorithm/transform_reduce) which computes the accumulated matrix product of a range of tiles defined by tile iterators. -![ALT](/media/images/cutlass-threadblock-mma-pipelined.png "cutlass::gemm::threadblock::MmaPipelined") +![ALT](../../images/cutlass-threadblock-mma-pipelined.png "cutlass::gemm::threadblock::MmaPipelined") In the case of GEMM, the tile iterators are -[cutlass::transform::threadblock::PredicatedTileIterator](/include/cutlass/transform/threadblock/predicated_tile_iterator.h) +[cutlass::transform::threadblock::PredicatedTileIterator](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/transform/threadblock/predicated_tile_iterator.h) to traverse a sequence of tiles in global memory with appropriate predication to avoid out-of-bounds memory accesses. @@ -213,14 +211,14 @@ The warp-level GEMM API is a generalization of CUDA's WMMA API to achieve the fo Defining a warp-level matrix multiply in CUTLASS is similar to WMMA as shown below. -![ALT](/media/images/cutlass-warp-level-gemm-api-instantiation.png "CUTLASS vs WMMA API") +![ALT](../../images/cutlass-warp-level-gemm-api-instantiation.png "CUTLASS vs WMMA API") The usage model is also similar. The following example computes a warp-level GEMM operation, accumulating a series of matrix products in a register-backed array. The input to a warp-level GEMM operation in CUTLASS _must_ be data in shared memory loaded by iterators or on register-backed fragments. -![ALT](/media/images/cutlass-warp-level-gemm-operation.png "CUTLASS warp-level GEMM API") +![ALT](../../images/cutlass-warp-level-gemm-operation.png "CUTLASS warp-level GEMM API") ```c++ #include "cutlass/gemm/warp/default_mma_tensor_op.h" @@ -513,8 +511,8 @@ column-major GEMM, operands A & B are transposed and swapped. To enable efficient row-major epilogue for both row-major and column-major output layout, CUTLASS' device-level GEMM operators `cutlass::device::Gemm` and `cutlass::device::GemmUniversal` provide two template definitions: -- (a) [General definition](/include/cutlass/gemm/device/gemm.h#L217) -- (b) [Specialized definition for column-major source/output](/include/cutlass/gemm/device/gemm.h#L545) +- (a) [General definition](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm.h#L217) +- (b) [Specialized definition for column-major source/output](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm.h#L545) Efficient row-major epilogue for: - (i) GEMM operator on row-major source/output uses template (a). It runs row-major GEMM and @@ -536,8 +534,8 @@ of input layouts. Thus, CUTLASS supports the following layout combinations for i CUTLASS defines a template-based interface to Tensor Core operations to avoid resorting to inline PTX. -- [mma_sm70.h](/include/cutlass/arch/mma_sm70.h) - Volta TensorCore operations -- [mma_sm75.h](/include/cutlass/arch/mma_sm75.h) - Turing TensorCore operations +- [mma_sm70.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/mma_sm70.h) - Volta TensorCore operations +- [mma_sm75.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/mma_sm75.h) - Turing TensorCore operations # Copyright diff --git a/media/docs/gemm_api_3x.md b/media/docs/cpp/gemm_api_3x.md similarity index 94% rename from media/docs/gemm_api_3x.md rename to media/docs/cpp/gemm_api_3x.md index ab6e6e09..c643fafd 100644 --- a/media/docs/gemm_api_3x.md +++ b/media/docs/cpp/gemm_api_3x.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS GEMM API") - -[README](../../README.md#documentation) > **CUTLASS 3.0 GEMM API** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS GEMM API") # CUTLASS 3.0 GEMM API @@ -71,7 +69,7 @@ is implied by CUDA grid launch semantics. However, for persistent kernels, these three loops are expressed in the source code as a single `while` loop that queries the -[work tile scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp) +[work tile scheduler](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp) for problem tiles on which to compute. Inside the three nested `for` loops, @@ -112,7 +110,7 @@ in order to assemble a kernel. This order is 3. wrap up the kernel with a device layer adapter. -This order is also reflected in the [CUTLASS 3.0 Hopper kernel examples](/examples/48_hopper_warp_specialized_gemm) as seen in the excerpt below. +This order is also reflected in the [CUTLASS 3.0 Hopper kernel examples](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm) as seen in the excerpt below. ```c++ // Step 1: Generate the required collective layer mainloop specialization @@ -208,7 +206,7 @@ Any looping over multiple tiles that the algorithm might need to do would happen here. The `CollectiveMma` class is declared in the header -[cutlass/gemm/collective/collective_mma.hpp](/include/cutlass/gemm/collective/collective_mma.hpp). +[cutlass/gemm/collective/collective_mma.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma.hpp). ```c++ namespace cutlass::gemm::collective { @@ -328,7 +326,7 @@ all operations that conceptually belong to the same class. This design has the f The primary `CollectiveMma` is intended to be an expert user interface that allows full control over all the properties of the collective's GPU micro-kernel. However, often a user just wants an off-the-shelf GEMM mainloop implementation parameterized on simple configuration parameters. CUTLASS 3.0 -provides [`cutlass::gemm::collective::CollectiveBuilder`](/include/cutlass/gemm/collective/collective_builder.hpp) for such scenarios. +provides [`cutlass::gemm::collective::CollectiveBuilder`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_builder.hpp) for such scenarios. ```c++ namespace cutlass::gemm::collective { @@ -382,7 +380,7 @@ may also change in the future as we adopt user feedback. If the builder is able to provide a collective mainloop type for the given set of parameters, it will be aliased within as `CollectiveOp`. For more information on how to -parameterize kernels conveniently with the collective builder, please see example [49_hopper_gemm_with_collective_builder](/examples/49_hopper_gemm_with_collective_builder). +parameterize kernels conveniently with the collective builder, please see example [49_hopper_gemm_with_collective_builder](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder). ### Epilogue @@ -390,7 +388,7 @@ The collective epilogue implements element-wise operations involving the output matrix. Users can provide a custom epilogue, or use one of the standard epilogues. These live in the directory -[include/cutlass/epilogue/collective/](/include/cutlass/epilogue/collective/), +[include/cutlass/epilogue/collective/](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/), and include classes like `cutlass::epilogue::collective::DefaultEpilogue` and @@ -418,7 +416,7 @@ epilogues, and/or other operations. The entry point API for CUTLASS 3.0 kernel is the class `cutlass::gemm::kernel::GemmUniversal`, found in the header file -[include/cutlass/gemm/kernel/gemm_universal.hpp](/include/cutlass/gemm/kernel/gemm_universal.hpp). +[include/cutlass/gemm/kernel/gemm_universal.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemm_universal.hpp). `GemmUniversal` is a stateless universal device kernel that implements GEMM as the composition of two parts: @@ -478,24 +476,24 @@ We will explain *collective* in more detail below. Specializations of `kernel::GemmUniversal` for 3.0 APIs live in any of various `gemm_*.hpp` files in the directory -[include/cutlass/gemm/kernel/](/include/cutlass/gemm/kernel/). +[include/cutlass/gemm/kernel/](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/). Specializations for 2.x APIs can be found in the header file -[include/cutlass/gemm/kernel/gemm_universal.h](/include/cutlass/gemm/kernel/gemm_universal.h). +[include/cutlass/gemm/kernel/gemm_universal.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemm_universal.h). CUTLASS 3.x implements various embodiments of `kernel::GemmUniversal`. Each kernel layer schedule is specialized for a GEMM scheduling algorithm and GPU architecture. Specializations of `kernel::GemmUniversal` for 3.0 APIs live in any of various `include/cutlass/gemm/kernel/{arch_tag}*.hpp` files in the directory -[include/cutlass/gemm/kernel/](/include/cutlass/gemm/kernel/). +[include/cutlass/gemm/kernel/](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/). Which specialization to dispatch to is decided through the dispatch policy's `Schedule` type. For example, the header file -[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) +[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) has a specialization of `kernel::GemmUniversal` for Hopper that uses a warp-specialized mainloop with a persistent scheduling algorithm, while the header file -[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) +[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) has a specialization of `GemmUniversal` for Hopper that uses a warp-specialized but non-persistent algorithm. @@ -536,7 +534,7 @@ It serves the same purpose as cuBLAS and behaves similarly. The entry point for the Device GEMM API is the class `cutlass::gemm::device::GemmUniversalAdapter`. This class lives in the header file -[include/cutlass/gemm/device/gemm_universal_adapter.h](/include/cutlass/gemm/device/gemm_universal_adapter.h). +[include/cutlass/gemm/device/gemm_universal_adapter.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_universal_adapter.h). `GemmUniversalAdapter` is a stateful, reusable handle, which is parameterized on the `cutlass::gemm::kernel` type. diff --git a/media/docs/grouped_scheduler.md b/media/docs/cpp/grouped_scheduler.md similarity index 93% rename from media/docs/grouped_scheduler.md rename to media/docs/cpp/grouped_scheduler.md index 4b86e915..333496f7 100644 --- a/media/docs/grouped_scheduler.md +++ b/media/docs/cpp/grouped_scheduler.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Grouped Kernel Schedulers") - -[README](../../README.md#documentation) > **Grouped Kernel Schedulers** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Grouped Kernel Schedulers") # CUTLASS Grouped Kernel Schedulers @@ -59,12 +57,12 @@ Consider, for example, the threadblock-to-tile mapping that occurs for a group o each consisting of a grid of 2x2 tiles. Suppose that eight threadblocks are launched. The figure below illustrates the threadblock ID assigned to each tile in each GEMM in the group. -![ALT](/media/images/grouped-gemm-schedule-2x2.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four GEMMs with 2x2 grids of tiles") +![ALT](../../images/grouped-gemm-schedule-2x2.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four GEMMs with 2x2 grids of tiles") A similar mapping for problems that do not have the same number of tiles is shown below: -![ALT](/media/images/grouped-gemm-schedule-varied.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four GEMMs with varying tile count") +![ALT](../../images/grouped-gemm-schedule-varied.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four GEMMs with varying tile count") ## Computing the schedule for a given block Each threadblock in the grouped GEMM computes its own schedule by calling @@ -114,7 +112,7 @@ of a grid of 2x2 tiles. Matrix C in each problem is lower triangular, indicated shaded tiles. Consider that eight threadblocks are launched to compute the grouped problem. The default grouped GEMM scheduler will assign threadblocks to tiles in the following order: -![ALT](/media/images/grouped-syr2k-schedule-using-grouped-gemm-scheduler.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four SYR2Ks with 2x2 grids of tiles") +![ALT](../../images/grouped-syr2k-schedule-using-grouped-gemm-scheduler.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four SYR2Ks with 2x2 grids of tiles") In this case, threadblocks 1 and 5 are continuously assigned to inactive tiles. In scenarios in which problems within the group have varying size, we have observed @@ -129,7 +127,7 @@ lower-triangular problem (and vice-versa for upper-triangular problems). Using the example above, the resulting assignment of threadblocks to tiles from such a scheduler might be: -![ALT](/media/images/grouped-syr2k-schedule-ideal.png "CUTLASS grouped SYR2K scheduler assigning threadblocks to four SYR2Ks with 2x2 grids of tiles") +![ALT](../../images/grouped-syr2k-schedule-ideal.png "CUTLASS grouped SYR2K scheduler assigning threadblocks to four SYR2Ks with 2x2 grids of tiles") Achieving this schedule requires mapping from a threadblock ID to tile coordinates `(i, j)`. @@ -139,7 +137,7 @@ first calculate row and column indices assuming one-indexed rows, tiles, and threadblock IDs, and then subtract one to convert to zero-indexed versions. Our description borrows heavily from the mapping described [here](https://stackoverflow.com/a/40954159). -![ALT](/media/images/grouped-syr2k-schedule-3x3.png "CUTLASS grouped SYR2K scheduler assigning threadblocks to one SYR2K with a 3x3 grids of tiles") +![ALT](../../images/grouped-syr2k-schedule-3x3.png "CUTLASS grouped SYR2K scheduler assigning threadblocks to one SYR2K with a 3x3 grids of tiles") ### Calculating row `i` given threadblock ID `t` For a given row i, all threadblock IDs t in that row satisfy the following: @@ -199,7 +197,7 @@ each of which contains 2 "true tiles." We can thus first map a threadblock ID to using the equations above, and then map it to the "true tile" within its "macro tile." In the example of a 2x4 grid, this mapping would look as follows: -![ALT](/media/images/grouped-syr2k-schedule-macro.png "CUTLASS grouped SYR2K scheduler converting a grid into a 'macro grid' for computing tile mappings for non-square grids") +![ALT](../../images/grouped-syr2k-schedule-macro.png "CUTLASS grouped SYR2K scheduler converting a grid into a 'macro grid' for computing tile mappings for non-square grids") A zero-indexed threadblock ID `t` is mapped to its "macro tile ID" `t_macro` as: ``` @@ -245,7 +243,7 @@ The only modification needed for upper-triangular matrices is to swap `i_macro` # Scheduler modes The grouped kernel schedulers come with two different modes for finding the next tile for a block to compute. These techniques are controlled by -the [`cutlass::gemm::kernel::GroupScheduleMode`](../../include/cutlass/gemm/kernel/grouped_problem_visitor.h) enum. +the [`cutlass::gemm::kernel::GroupScheduleMode`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/grouped_problem_visitor.h) enum. We describe each mode in greater detail below. ## `GroupScheduleMode::kDeviceOnly` (default) @@ -340,7 +338,7 @@ Thus, there are 216 tiles across the group. Suppose this grouped GEMM is run on GA100, which has 108 SMs. Suppose that the occupancy given the parameters of the grouped GEMM is one -- one threadblock can be active at a time on an SM. The grouped GEMM will, thus, run with 108 -persistent threadblocks, each of which computes (216 / 108) = 2 tiles. +persistent threadblocks, each of which computes (256 / 108) = 2 tiles. Under the round-robin assignment of tiles to threadblocks employed by the grouped GEMM scheduler, the assignment of tiles to threadblocks @@ -379,7 +377,7 @@ scheduling mode by around 30%. To ease the process of sorting groups and their associated metadata in this manner, the device-level grouped kernels provide a `sort_problems()` method. -An example of how to use this may be found in the [grouped GEMM example](../../examples/24_gemm_grouped/gemm_grouped.cu). +An example of how to use this may be found in the [grouped GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped/gemm_grouped.cu). Finally, while sorting problems can be helpful in certain scenarios, it is not guaranteed to improve performance. In some cases, performance can diff --git a/media/docs/ide_setup.md b/media/docs/cpp/ide_setup.md similarity index 97% rename from media/docs/ide_setup.md rename to media/docs/cpp/ide_setup.md index 9b023659..6a332b31 100644 --- a/media/docs/ide_setup.md +++ b/media/docs/cpp/ide_setup.md @@ -1,5 +1,3 @@ -[README](../../README.md#documentation) > **IDE Setup for CUTLASS Development** - # IDE Setup for CUTLASS Development This document outlines instructions and tips for setting up a local editor for CUTLASS development, including support @@ -33,7 +31,7 @@ and you might see faster responses and more stable performance with clangd. * ...others, depending on which files you edit 1. Edit C++ standard to be `c++17`, `gnu++17`, or equivalent. 1. Edit `defines` to define preprocessor variables. See -[Global Config below](#Global-Config) for examples. The important +[Global Config below](#global-config) for examples. The important ones include `__CUDACC_VER_MAJOR__`, `__CUDA_ARCH__`, `__CUDA_ARCH_FEAT_SM90_ALL__`. But configure them according to your target architecture. 1. ...and possible edit any other fields for your specific setup. diff --git a/media/docs/implicit_gemm_convolution.md b/media/docs/cpp/implicit_gemm_convolution.md similarity index 84% rename from media/docs/implicit_gemm_convolution.md rename to media/docs/cpp/implicit_gemm_convolution.md index 9b00cfc2..d65b9a90 100644 --- a/media/docs/implicit_gemm_convolution.md +++ b/media/docs/cpp/implicit_gemm_convolution.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Implicit GEMM API") - -[README](../../README.md#documentation) > **Implicit GEMM Convolution** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Implicit GEMM API") # CUTLASS Convolution @@ -55,7 +53,7 @@ f(p, r) = p * stride_h + R - r - 1 + pad_h g(q, s) = q * stride_w + S - s - 1 + pad_w ``` -A [host](/tools/util/include/cutlass/util/reference/host/convolution.h) and [device](/tools/util/include/cutlass/util/reference/device/convolution.h) +A [host](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/reference/host/convolution.h) and [device](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/reference/device/convolution.h) reference implementation are provided in the CUTLASS Utilities. This computation may be mapped to the elements of a matrix product as follows. @@ -145,7 +143,7 @@ for (int gemm_i = 0; gemm_i < GEMM_M; ++gemm_i) { } } ``` -The [CUTLASS GEMM implementation](/media/docs/efficient_gemm.md) explicitly iterates over tiles. Consequently, +The [CUTLASS GEMM implementation](efficient_gemm.md) explicitly iterates over tiles. Consequently, a tile iterator could be implemented to compute these functions analytically and load the appropriate elements. However, the resulting modulo arithmetic would be computationally intensive, and overhead would limit performance of a GEMM kernel targeting Turing Tensor Cores. @@ -169,7 +167,7 @@ This enables 128-bit vector memory acceses which lead to efficient CUDA kernels. CUTLASS defines CUDA C++ templates accepting numerous template arguments to specialize the resulting kernel by operation, data type, tile configuration, math instruction, and fused output operation. -In [turing_tensorop_conv2dfprop.cu](/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu), a convolution +In [turing_tensorop_conv2dfprop.cu](https://github.com/NVIDIA/cutlass/tree/main/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu), a convolution operation is defined as follows. ```c++ @@ -232,7 +230,7 @@ Internal accumulation is performed using 32-bit integers (`int32_t`), and an ele is performed on the output in single-precision floating point (`float`). The threadblock and warp-level tile shapes refer to the hierarchically blocked GEMM computation -[described here](/media/docs/gemm_api.md). Larger tiles achieve greater reuse of data loaded through shared memory +[described here](gemm_api.md). Larger tiles achieve greater reuse of data loaded through shared memory but launch fewer CTAs and may not fully occupy the GPU for small problem sizes. Smaller tile configurations achieve lower peak utilizations but may better match the number of SMs within the GPU for real-world workloads. @@ -318,7 +316,7 @@ if (status != cutlass::Status::kSuccess) { ``` The example demonstrates how the input and output tensors may be written to a file as CSV using -`cutlass::HostTensor<>` defined in the [CUTLASS Utilities](/media/docs/utilities.md). +`cutlass::HostTensor<>` defined in the [CUTLASS Utilities](utilities.md). ```c++ std::ofstream output_workspace(ss.str()); @@ -339,41 +337,41 @@ The example demonstrates how the input and output tensors may be written to a fi CUTLASS defines the following CUDA C++ templates to implement Implicit GEMM Convolution which are described in greater detail in subsequent sections. **Activations tile iterators** load the activations tile into registers. Two implementations are provided: -- [conv2d_fprop_activation_tile_access_iterator_analytic.h](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h) computes pointer deltas and masks analytically -- [conv2d_fprop_activation_tile_access_iterator_optimized.h](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h) optimizes iterating over global memory and +- [conv2d_fprop_activation_tile_access_iterator_analytic.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h) computes pointer deltas and masks analytically +- [conv2d_fprop_activation_tile_access_iterator_optimized.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h) optimizes iterating over global memory and creating GEMM-A tile in shared memory. **Filter tile iterators** load filters into registers. Similarly, two implementations are provided: -- [conv2d_fprop_filter_tile_access_iterator_analytic.h](/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h) computes pointer deltas and masks analytically -- [conv2d_fprop_filter_tile_access_iterator_optimized.h](/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h) optimizes iterating over global memory and +- [conv2d_fprop_filter_tile_access_iterator_analytic.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h) computes pointer deltas and masks analytically +- [conv2d_fprop_filter_tile_access_iterator_optimized.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h) optimizes iterating over global memory and creating GEMM-B tile in shared memory. The improvements covered by optimized iterators are: a. Precomputing kernel-invariant pointer deltas on the host b. Computing cta-invariant mask predicates on device-side iterator ctors -c. Use of [fast divmod](/include/cutlass/fast_math.h) to map GEMM dimensions to convolution tensors. +c. Use of [fast divmod](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/fast_math.h) to map GEMM dimensions to convolution tensors. For example, an _optimized_ activation iterator uses fast divmod to map GEMM _M_ to NPQ. **Pipelined mainloop** loads threadblock-scoped tiles from global memory into shared memory and then applies CUTLASS warp-level GEMM operations to load from Shared Memory and issue instructions to Turing Tensor Cores. -- [mma_pipelined.h](/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h) +- [mma_pipelined.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h) Operations for storing to shared memory and performing warp-wide matrix multiply operations using Turing Tensor Cores are applied directly from the CUTLASS GEMM components. These include the following components. **Regular Tile Iterator** implemented in -[transform::threadblock::RegularTileIterator](/include/cutlass/transform/threadblock/regular_tile_iterator.h) +[transform::threadblock::RegularTileIterator](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/transform/threadblock/regular_tile_iterator.h) stores register-backed fragments to Shared Memory in permuted layouts. -**Warp-level GEMM** defined in [cutlass::gemm::warp::MmaTensorOp](/include/cutlass/gemm/warp/mma_tensor_op.h) +**Warp-level GEMM** defined in [cutlass::gemm::warp::MmaTensorOp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/warp/mma_tensor_op.h) defines tile iterators to load from Shared Memory and issue math instructions to Turing Tensor Cores. -Further details are [described in here](/media/docs/gemm_api.md#warp-level-matrix-multiply-api). +Further details are [described in here](gemm_api.md#warp-level-matrix-multiply-api). **Epilogue** reorders accumulator elements among threads within a threadblock to efficiently update -the output tensor. It is implemented in [epilogue::threadblock::Epilogue](/include/cutlass/epilogue/threadblock/epilogue.h). +the output tensor. It is implemented in [epilogue::threadblock::Epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/epilogue.h). ### Loading Activations and Filters @@ -383,7 +381,7 @@ of channels. After iterating over all filter positions, the convolution algorith next interval of channels and proceeds from filter `r=0, s=0`. The matrix product of one threadblock tile is computed per iteration of -the mainloop as described in the [CUTLASS GEMM implementation](/media/docs/efficient_gemm.md). To +the mainloop as described in the [CUTLASS GEMM implementation](efficient_gemm.md). To summarize, the threadblock tile of activations and filters are loaded from tensors in global memory and stored to shared memory. Each thread within the threadblock loads one or more vectors and collectively span the entire tile. @@ -394,9 +392,9 @@ Filters tensors. Each index in the GEMM _M_ dimension corresponds to a unique _( index of the output tensor, and pointers may be computed based on this as well as filter position _(r,s)_. -![ALT](/media/images/conv2d-fprop-int4.png "Convolution Forward Propagation on INT4 data.") +![ALT](../../images/conv2d-fprop-int4.png "Convolution Forward Propagation on INT4 data.") -The CUTLASS component that embodies this functionality is [Conv2dFpropFilterTileAccessIteratorAnalytic](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h). +The CUTLASS component that embodies this functionality is [Conv2dFpropFilterTileAccessIteratorAnalytic](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h). Its constructor computes the mapping of GEMM _M_ to _(N, P, Q)_, the `at()` method maps the linear offset into the Activations tensor for each memory access the thread is to perform. Additionally, the method `valid()` computes the valided of the access for each filter position and for each memory access to indicate whether the memory access will be within the bounds of the @@ -456,11 +454,11 @@ void advance() { } ``` -Similar logic holds for [Conv2dFpropFilterTileAccessIteratorAnalytic](/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h). +Similar logic holds for [Conv2dFpropFilterTileAccessIteratorAnalytic](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h). To reduce computational overhead in the mainloop body, the pointer offsets may be precomputed in host code and provided to the CUDA kernel as a lookup table in its `Params` structure. -As shown in [Conv2dFpropFilterTileAccessIteratorOptimized](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h), +As shown in [Conv2dFpropFilterTileAccessIteratorOptimized](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h), the logic to compute offsets from filter position has been extracted to the `Params` constructor. ```c++ @@ -535,11 +533,11 @@ threads within a warp. The following operations are supported. Functionally, the Turing 8x8x32 matrix multiply operation distributes the _A_, _B_, and _C_ matrix across 32 threads within a warp according to the following illustration. -![ALT](/media/images/mma-8x8x32.png "Turing Tensor Op") +![ALT](../../images/mma-8x8x32.png "Turing Tensor Op") This Tensor Core operation is accessible to the CUDA programmer via the PTX instruction [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-8832). -CUTLASS wraps inline PTX with device-side intrinsics defined in [`cutlass/arch/mma_sm75.h`](/include/cutlass/arch/mma_sm75.h) +CUTLASS wraps inline PTX with device-side intrinsics defined in [`cutlass/arch/mma_sm75.h`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/mma_sm75.h) as in the following example. ```c++ @@ -565,7 +563,7 @@ per row. The arrangement of SMEM pointers and destination registers within threads is illustrated as follows. Thread 0 is highlighted in the illustration to emphasize the mapping. -![ALT](/media/images/ldmatrix-8x128bx4.png "Turing ldmatrix PTX instruction") +![ALT](../../images/ldmatrix-8x128bx4.png "Turing ldmatrix PTX instruction") The size of the Turing Tensor Core operation computing matrix multiply-accumulate on INT4 data is 8-by-8-by-32 elements. `ldmatrix` fetches up to 32 rows (or columns) per operation. Sixteen Tensor Core operations may be issued @@ -574,7 +572,7 @@ as shown in the following figure. Larger tiles are possible by increasing the nu and issuing more Tensor Core operations, up to warp-level matrix operations of size 64-by-64-by-32. The limit is the number of registers to hold the accumulator elements. -![ALT](/media/images/ldmatrix-tensorop-32x32x32.png "Turing ldmatrix PTX instruction feeding Tensor Core operations") +![ALT](../../images/ldmatrix-tensorop-32x32x32.png "Turing ldmatrix PTX instruction feeding Tensor Core operations") ### Shared Memory Layouts @@ -588,7 +586,7 @@ load from Shared Memory using `ldmatrix`. The following figure illustrates the t the loading the activations and filters threadblock tiles from global memory and the permuted layout in Shared Memory. -![ALT](/media/images/tensor-op-permuted-smem-layout-TN.png "Shared Memory layout used for Turing Tensor Cores") +![ALT](../../images/tensor-op-permuted-smem-layout-TN.png "Shared Memory layout used for Turing Tensor Cores") In the illustration, one warp-wide memory access is highlighted in blue, with individual threads loading one 128-bit vector. The tile in global memory could correspond either to the activations @@ -618,7 +616,7 @@ The following figure shows how the first sixteen threads participating in an `ld logically map to the c=0..31 slice of a matrix in Shared Memory. This slice is known as a "k-group" within the code because it corresponds to the same K-index of a warp-level matrix multiply. -![ALT](/media/images/tensor-op-permuted-smem-layout-TN-k0.png "Load kgroup=0 from Shared Memory using ldmatrix") +![ALT](../../images/tensor-op-permuted-smem-layout-TN-k0.png "Load kgroup=0 from Shared Memory using ldmatrix") The lower half of the figure shows the physical arrangement in Shared Memory, with threads offset by row and column according to the XOR function. By inspection, we can observe there are no bank conflicts, as _T0 ... T7_ each access unique @@ -632,9 +630,9 @@ the following sequence: - **^3** advances from _k=3_ to _k=0_ The first of these transitions is shown below. -![ALT](/media/images/tensor-op-permuted-smem-layout-TN-k1.png "Advance to kgroup=1 from Shared Memory using ldmatrix") +![ALT](../../images/tensor-op-permuted-smem-layout-TN-k1.png "Advance to kgroup=1 from Shared Memory using ldmatrix") -The [CUTLASS warp-level GEMM API](/media/docs/gemm_api.md#warp-level-matrix-multiply-api) defines templates for +The [CUTLASS warp-level GEMM API](gemm_api.md#warp-level-matrix-multiply-api) defines templates for loading slices of data from permuted Shared Memory and issuing operations to Tensor Cores. ### Updating the Output Tensor @@ -647,11 +645,11 @@ needed. The **Epilogue** is the component for exchanging accumulator elements through Shared Memory, loading slices of the output matrix or tensor, applying an elementwise operation such as linear scaling or bias, and storing the result to the output tensor. CUTLASS structures this as several components: -- [cutlass::epilogue::threadblock::Epilogue](/include/cutlass/epilogue/threadblock/epilogue.h) - the top-level component for looping over the entire threadblock tile -- [cutlass::epilogue::warp::TileIteratorTensorOp](/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h) - a specialized component for storing accumulators for Tensor Core to Shared Memory -- [cutlass::epilogue::threadblock::SharedLoadIterator](/include/cutlass/epilogue/threadblock/shared_load_iterator.h) - a component for loading elements from a row-major arrangement in Shared Memory -- [cutlass::epilogue::threadblock::PredicatedTileIterator](/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h) - a component for loading or storing matrix fragments to Global Memory (with bounds checks) -- [cutlass::epilogue::thread::LinearCombination](/include/cutlass/epilogue/thread/linear_combination.h) - an element-wise function computing `alpha * AB + beta * C` to compute the final output +- [cutlass::epilogue::threadblock::Epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/epilogue.h) - the top-level component for looping over the entire threadblock tile +- [cutlass::epilogue::warp::TileIteratorTensorOp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h) - a specialized component for storing accumulators for Tensor Core to Shared Memory +- [cutlass::epilogue::threadblock::SharedLoadIterator](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/shared_load_iterator.h) - a component for loading elements from a row-major arrangement in Shared Memory +- [cutlass::epilogue::threadblock::PredicatedTileIterator](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h) - a component for loading or storing matrix fragments to Global Memory (with bounds checks) +- [cutlass::epilogue::thread::LinearCombination](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination.h) - an element-wise function computing `alpha * AB + beta * C` to compute the final output ## Unit Tests @@ -663,13 +661,13 @@ b. showcase instantiation of use of these templates in device code, and c. assert functional correctness. **Convolution unit tests** -- Device-wide convolution operator: [conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu) +- Device-wide convolution operator: [conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu) **GEMM unit tests** -- Warp-scoped matrix multiply for Turing Tensor Cores: [gemm_sm75.cu](/test/unit/gemm/warp/gemm_sm75.cu) +- Warp-scoped matrix multiply for Turing Tensor Cores: [gemm_sm75.cu](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/warp/gemm_sm75.cu) **Epilogue unit tests** -- Epilogue for Turing Tensor Cores: [epilogue_tensor_op.cu](/test/unit/epilogue/threadblock/epilogue_tensor_op.cu) +- Epilogue for Turing Tensor Cores: [epilogue_tensor_op.cu](https://github.com/NVIDIA/cutlass/tree/main/test/unit/epilogue/threadblock/epilogue_tensor_op.cu) # Convolution Example @@ -681,10 +679,10 @@ of Implicit GEMM Convolution. Example `09_turing_tensorop_conv2dfprop` computes a forward convolutional layer in which inputs and outputs are 4-b integers. The example source is visible in -[examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu](/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu). +[examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu](https://github.com/NVIDIA/cutlass/tree/main/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu). -Before building the example, first perform the prerequisite steps for building any CUTLASS component [described here](/media/docs/quickstart.md). +Before building the example, first perform the prerequisite steps for building any CUTLASS component [described here](quickstart.md). Compute capability 7.5 refers to the Turing architecture, and this work requires CUDA 10.2 Toolkit or later to target Turing Tensor Cores using the native `mma` [PTX instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-8832). @@ -708,7 +706,7 @@ initialize them to random values, and compute the result of a convolutional laye tensors may be saved to .csv files, and the CUTLASS host-side reference check may be executed to verify correctness. The complete usage statement is visible by running with `--help`: -```bash +``` $ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --help 09_turing_tensorop_conv2dfprop example diff --git a/media/docs/layout.md b/media/docs/cpp/layout.md similarity index 98% rename from media/docs/layout.md rename to media/docs/cpp/layout.md index bd544c0a..5e1d4d29 100644 --- a/media/docs/layout.md +++ b/media/docs/cpp/layout.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Layouts and Tensors") - -[README](../../README.md#documentation) > **Layouts and Tensors** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Layouts and Tensors") Note: This document talks about CUTLASS 2.x layout tag types. CUTLASS 3.0 deprecates all legacy 2.x layout tags in favour of a single `cute::Layout` diff --git a/media/docs/cpp/overview.md b/media/docs/cpp/overview.md new file mode 100644 index 00000000..b696686a --- /dev/null +++ b/media/docs/cpp/overview.md @@ -0,0 +1,619 @@ +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") + +# Overview + +# CUTLASS 3.9.0 + +_CUTLASS 3.9.0 - March 2025_ + +CUTLASS is a collection of CUDA C++ template abstractions for implementing +high-performance matrix-matrix multiplication (GEMM) and related computations at all levels +and scales within CUDA. It incorporates strategies for hierarchical decomposition and +data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes +these "moving parts" into reusable, modular software components abstracted by C++ template +classes. Primitives for different levels of a conceptual parallelization hierarchy +can be specialized and tuned via custom tiling sizes, data types, +and other algorithmic policy. The resulting flexibility simplifies their use +as building blocks within custom kernels and applications. + +To support a wide variety of applications, CUTLASS provides extensive support for +mixed-precision computations, providing specialized data-movement and +multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16, +[FP32 emulation via tensor core instruction](https://github.com/NVIDIA/cutlass/tree/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm), + 8b floating point types (e5m2 and e4m3), + block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8), + narrow integer types (4 and 8b signed and unsigned integers), + and binary 1b data types (where architectures allow for the +native support of such data types). +CUTLASS demonstrates optimal matrix multiply operations +targeting the programmable, high-throughput _Tensor Cores_ implemented by +NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures. + +In addition to GEMMs, CUTLASS implements high-performance convolution via +the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution +operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. +This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. + +See the [Quick Start Guide](quickstart.md) to get started quickly. + +See the [functionality docs](functionality.md) for a more comprehensive +list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU +architecture. + +# What's New in CUTLASS 3.9 + +* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API: + - Collective mainloops that target for: + * [Blockscaled datatypes with support for dense GEMM](../../../include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp) + * [Blockscaled datatypes with support for sparse GEMM](../../../include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp) + - New [GEMM](../../../include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](../../../include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - [Blackwell SM120 epilogue](../../../include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](../../../include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp). +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture: + - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](../../../examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). + - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](../../../examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). + - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](../../../examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). +* Set of unit tests that demonstrate the usage of both [sparse](../../../test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](../../../test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. +* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: + - Enhancement of [blockwise GEMM](../../../examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. + - Enhancement of [groupwise GEMM](../../../examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. + - Support for [grouped GEMM with blockwise and groupwise scaling](../../../examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [blockwise GEMM](../../../examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. + - Support for [groupwise GEMM](../../../examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. + - Support for [grouped GEMM with blockwise](../../../examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](../../../examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture. +* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler: + - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. + - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance. + - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. + - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + +Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits. +CUTLASS team is working on a fix. + +**See the [CHANGELOG](../release_notes.md) for details of all past releases and updates.** + +# Performance + +CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels, +they exhibit nearly optimal utilization of peak theoretical throughput. The figure below +shows CUTLASS 3.8's performance as a % of theoretical peak utilization +on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU. + +![ALT](../../images/cutlass-3.8-blackwell-gemm-peak-performance.svg "") + +The two figures below show the continual CUTLASS performance improvements +on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since +CUTLASS 3.1. +CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads). +Tensor Core operations are implemented using CUDA's +[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and +[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions. + +![ALT](../../images/cutlass-3.5.1-gemm-peak-performance.png "") +![ALT](../../images/cutlass-3.5.1-gemm-peak-performance-fp8.png "") + +# CuTe + +CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data. +CuTe is a collection of C++ CUDA template abstractions for +defining and operating on hierarchically multidimensional layouts of threads and data. +CuTe provides `Layout` and `Tensor` objects that compactly package the type, +shape, memory space, and layout of data, while performing the complicated indexing for the user. +This lets programmers focus on the logical descriptions of their algorithms while +CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, +implement, and modify all dense linear algebra operations. + +The core abstractions of CuTe are hierarchically multidimensional layouts +which can be composed with data arrays to represent tensors. +The representation of layouts is powerful enough to represent nearly +everything we need to implement efficient dense linear algebra. +Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning. + +CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. +This greatly simplifies the design and improves code composability and readability. +More documentation specific to CuTe can be found in its +[dedicated documentation directory](cute/00_quickstart.md). + +# Compatibility + +Minimum requirements: + +- Architecture: Volta (compute capability 7.0) +- Compiler: Must support at least C++17 +- CUDA Toolkit version: 11.4 + +CUTLASS requires a C++17 host compiler and +performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads). +It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions. + +## Operating Systems + +We have tested the following environments. + +|**Operating System** | **Compiler** | +|-----------------|----------| +| Ubuntu 18.04 | GCC 7.5.0 | +| Ubuntu 20.04 | GCC 10.3.0 | +| Ubuntu 22.04 | GCC 11.2.0 | + +Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended. + +Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits. +CUTLASS team is working on a fix. + +## Hardware + +CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs. + +|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**| +|---|---|---| +|NVIDIA V100 Tensor Core GPU |7.0|11.4| +|NVIDIA TitanV |7.0|11.4| +|NVIDIA GeForce RTX 20x0 series |7.5|11.4| +|NVIDIA T4 |7.5|11.4| +|NVIDIA A100 Tensor Core GPU |8.0|11.4| +|NVIDIA A10 |8.6|11.4| +|NVIDIA GeForce RTX 30x0 series |8.6|11.4| +|NVIDIA GeForce RTX 40x0 series |8.9|11.8| +|NVIDIA L40 |8.9|11.8| +|NVIDIA H100 Tensor Core GPU |9.0|11.8| +|NVIDIA H200 Tensor Core GPU |9.0|11.8| +|NVIDIA B200 Tensor Core GPU |10.0|12.8| +|NVIDIA GeForce RTX 50x0 series |10.0|12.8| + +## Target Architecture + +In general, PTX code generated for one target architecture can be run on future architectures +(i.e., it is forward compatible). +However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose +PTX does not have forward compatibility guarantees. +Several Hopper and Blackwell PTX instructions fall under this category of +architecture-accelerated features, and thus require a `sm_90a` or `sm100a` target architecture +(note the "a" appended). For more details on this and other architecture-accelerated instructions, +please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability). + +The target architecture information is passed on to CUTLASS via the cmake flag +`CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, +users are required to build CUTLASS with `90a` as the target architecture. +If a user accidentally builds a kernel which uses SM90a features +(e.g. Hopper Tensor Core Instructions), using the SM90 target +(note the lack of "a"), with either CUDA Toolkit 12 or 11.8, +the kernel is expected to fail with a runtime error. + +``` +cmake .. -DCUTLASS_NVCC_ARCHS="90a" +``` +Or + +``` +cmake .. -DCUTLASS_NVCC_ARCHS="100a" +``` + +Note: The NVIDIA Blackwell SM100 architecture used in the datacenter +products has a different compute capability than the one underpinning +NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels +compiled for Blackwell SM100 architecture with arch conditional features +(using `sm100a`) are not compatible with RTX 50 series GPUs. + +Please refer to the [functionality documentation](functionality.md) +for details on which kernels require which target architectures. + +# Documentation + +CUTLASS is described in the following documents and the accompanying +[Doxygen documentation](https://nvidia.github.io/cutlass). + +- [Quick Start Guide](quickstart.md) - basics of building and running CUTLASS +- [Functionality](functionality.md) - summarizes functionality available in CUTLASS +- [Efficient GEMM in CUDA](efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA +- [CUTLASS 3.x Design](cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components +- [GEMM API 3.x](gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts +- [GEMM API 2.x](gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts +- [Implicit GEMM Convolution](implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS +- [Code Organization](code_organization.md) - describes the organization and contents of the CUTLASS project +- [Terminology](terminology.md) - describes terms used in the code +- [Programming Guidelines](programming_guidelines.md) - guidelines for writing efficient modern CUDA C++ +- [Fundamental types](fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays +- [Layouts](layout.md) - describes layouts of matrices and tensors in memory +- [Tile Iterators](tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory +- [CUTLASS Profiler](profiler.md) - command-line driven profiling application +- [CUTLASS Utilities](utilities.md) - additional templates used to facilitate rapid development +- [Dependent kernel launch](dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent +kernels in the same stream, and how it is used in CUTLASS. + +# Resources +We have also described the structure of an efficient GEMM in our talk at the +[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf). + +- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/) +- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/) +- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/) +- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/) +- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/) + +# Building CUTLASS + +CUTLASS is a header-only template library and does not need to be built to be used by other +projects. Client applications should target CUTLASS's `include/` directory in their include +paths. + +CUTLASS unit tests, examples, and utilities can be build with CMake. +The minimum version of CMake is given in the [Quickstart guide](quickstart.md). +Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed +on your system. + +```bash +$ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc +``` + +Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels +for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6, 8.9, and 9.0. +To reduce compile time you can specify +the architectures to build CUTLASS for by changing the CMake configuration setting +`CUTLASS_NVCC_ARCHS`. + +```bash +$ mkdir build && cd build + +$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA's Ampere Architecture +``` + +From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make. + +The unit tests are organized as several binaries mirroring the top-level namespaces of CUTLASS, +and they may be executed in parallel via make's `-j` command line argument. + +```bash +$ make test_unit -j +... +... +... +[----------] Global test environment tear-down +[==========] 946 tests from 57 test cases ran. (10812 ms total) +[ PASSED ] 946 tests. +``` + +All tests should pass on supported platforms, though the exact number of tests may vary over time. + + +# Project Structure + +CUTLASS is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests. +[Doxygen documentation](https://nvidia.github.io/cutlass) provides a complete list of files, classes, +and template concepts defined in the CUTLASS project. + +A detailed explanation of the source code organization may be found in the +[CUTLASS documentation](code_organization.md), but several main components are summarized below. + +## CUTLASS Template Library + +``` +include/ # client applications should target this directory in their build's include paths + + cutlass/ # CUDA Templates for Linear Algebra Subroutines and Solvers - headers only + + arch/ # direct exposure of architecture features (including instruction-level GEMMs) + + conv/ # code specialized for convolution + + epilogue/ # code specialized for the epilogue of gemm/convolution + + gemm/ # code specialized for general matrix product computations + + layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory + + platform/ # CUDA-capable Standard Library components + + reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model + + thread/ # simt code that can be performed within a CUDA thread + + transform/ # code specialized for layout, type, and domain transformations + + * # core vocabulary types, containers, and basic numeric operations + + cute/ # CuTe Layout, layout algebra, MMA/Copy atoms, tiled MMA/Copy + + algorithm/ # Definitions of core operations such as copy, gemm, and operations on cute::tuples + + arch/ # Bare bones PTX wrapper structs for copy and math instructions + + atom/ # Meta-information either link to or built from arch/ operators + + mma_atom.hpp # cute::Mma_Atom and cute::TiledMma + + copy_atom.hpp # cute::Copy_Atom and cute::TiledCopy + + *sm*.hpp # Arch specific meta-information for copy and math operations + + * # Core library types such as Shape, Stride, Layout, Tensor, and associated operations + +``` + +### CUTLASS SDK Examples + +[CUTLASS SDK examples](https://github.com/NVIDIA/cutlass/tree/main/examples) apply CUTLASS templates to implement basic computations. + +### Tools + +``` +tools/ + library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates + include/ + cutlass/ + library/ + + profiler/ # CUTLASS Profiler - command-line utility for executing operations in the + # CUTLASS Library + + util/ # CUTLASS Utilities - contains numerous helper classes for + include/ # manging tensors in device memory, reference + cutlass/ # implementations for GEMM, random initialization + util/ # of tensors, and I/O. +``` + +### Test + +The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate +basic usage of Core API components and complete tests of the CUTLASS GEMM computations. + +Instructions for building and running the Unit tests are described in the [Quickstart guide](quickstart.md). + +# Performance Profiling + +The `tools/profiler/` directory contains a command-line utility for launching each of the GEMM kernels. +It can be built as follows: + +```bash +$ make cutlass_profiler -j16 +``` +## Building all GEMM and Convolution kernels (_long_ build times) + +By default, only one tile size is instantiated for each data type, math instruction, and layout. +To instantiate all, set the following environment variable when running CMake from an empty `build/` directory. +Beware, this results in *tens of thousands* of kernels and long build times. +This would also result in a large binary size and on some platforms linker to fail on building the library. +Therefore, it's highly recommended to generate only a subset of kernels as demonstrated in the sub-section below. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=all +... +$ make cutlass_profiler -j16 +``` + +## Building a subset of GEMM and Convolution kernels (_reduced_ build times) + +To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with +wildcard characters may be used to reduce the set of kernels. The following examples show building exactly one +or a subset of kernels for NVIDIA Ampere and Turing architecture: + +### Building a subset Tensor Core GEMM kernels + +To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targeting NVIDIA Ampere and Turing architecture, +use the below cmake command line: +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*gemm_f16_*_nt_align8 +... +$ make cutlass_profiler -j16 +``` + +Example command line for profiling a subset of Tensor Core GEMM kernels is as follows: +```bash +./tools/profiler/cutlass_profiler --kernels=cutlass_tensorop_s*gemm_f16_*_nt_align8 --m=3456 --n=4096 --k=4096 + +... +============================= + Problem ID: 1 + + Provider: CUTLASS + OperationKind: gemm + Operation: cutlass_tensorop_s1688gemm_f16_256x128_32x2_nt_align8 + + Status: Success + Verification: ON + Disposition: Passed + +reference_device: Passed + cuBLAS: Passed + + Arguments: --gemm_kind=universal --m=3456 --n=4096 --k=4096 --A=f16:column --B=f16:row --C=f32:column --alpha=1 \ + --beta=0 --split_k_slices=1 --batch_count=1 --op_class=tensorop --accum=f32 --cta_m=256 --cta_n=128 \ + --cta_k=32 --stages=2 --warps_m=4 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=8 --min_cc=75 \ + --max_cc=1024 + + Bytes: 118489088 bytes + FLOPs: 115992428544 flops + + Runtime: 1.55948 ms + Memory: 70.7616 GiB/s + + Math: 74378.8 GFLOP/s + + + +============================= +... +``` + +### Building one CUDA Core GEMM kernel + +To compile one SGEMM kernel targeting NVIDIA Ampere and Turing architecture, use the below cmake command line: +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1 +... +$ make cutlass_profiler -j16 +``` + +Example command line for profiling single SGEMM CUDA kernel is as follows: +```bash +$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096 + +============================= + Problem ID: 1 + + Provider: CUTLASS + OperationKind: gemm + Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1 + + Status: Success + Verification: ON + Disposition: Passed + + cuBLAS: Passed + + Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \ + --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \ + --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024 + + Bytes: 180355072 bytes + FLOPs: 115992428544 flops + + Runtime: 6.73655 ms + Memory: 24.934 GiB/s + + Math: 17218.4 GFLOP/s + +============================= +``` + +### Building a subset of Tensor Core Convolution kernels + +To compile a subset of Tensor core convolution kernels implementing forward propagation (fprop) with FP32 accumulation +and FP16 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line: +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*fprop_optimized_f16 +... +$ make cutlass_profiler -j16 +``` + +Example command line for profiling a subset of Tensor Core convolution kernels is as follows: + +```bash +$ ./tools/profiler/cutlass_profiler --kernels=cutlass_tensorop_s*fprop_optimized_f16 --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 + +... +============================= + Problem ID: 1 + + Provider: CUTLASS + OperationKind: conv2d + Operation: cutlass_tensorop_s16816fprop_optimized_f16_128x128_32x5_nhwc + + Status: Success + Verification: ON + Disposition: Passed + +reference_device: Passed + + Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1 \ + --stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f16:nhwc --Filter=f16:nhwc --Output=f32:nhwc \ + --conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 \ + --eq_gemm_provider=none --op_class=tensorop --accum=f32 --cta_m=128 --cta_n=128 --cta_k=32 --stages=5 \ + --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=16 --min_cc=80 --max_cc=1024 + + Bytes: 1130659840 bytes + FLOPs: 118482796544 flops + + Runtime: 0.711496 ms + Memory: 1479.99 GiB/s + + Math: 166526 GFLOP/s + +============================= +... +``` + + +### Building one Convolution CUDA kernel + +To compile and run one CUDA Core convolution kernel implementing forward propagation (fprop) with F32 accumulation +and FP32 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line: +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc +... +$ make cutlass_profiler -j16 +``` + +Example command line for profiling one CUDA Core convolution kernel: + +```bash +$ ./tools/profiler/cutlass_profiler --kernels=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 + + +============================= + Problem ID: 1 + + Provider: CUTLASS + OperationKind: conv2d + Operation: cutlass_simt_sfprop_optimized_128x128_8x2_nhwc + + Status: Success + Verification: ON + Disposition: Passed + +reference_device: Passed + + Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1 \ + --stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f32:nhwc --Filter=f32:nhwc --Output=f32:nhwc \ + --conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 \ + --eq_gemm_provider=none --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \ + --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024 + + Bytes: 2055798784 bytes + FLOPs: 118482796544 flops + + Runtime: 7.34266 ms + Memory: 260.752 GiB/s + + Math: 16136.2 GFLOP/s + + +============================= + +``` + +## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler +- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels: + - [GEMM CMake Examples](quickstart.md#gemm-cmake-examples) + - [Implicit GEMM convolution CMake Examples](quickstart.md#convolution-cmake-examples) +- [Further details about the CUTLASS Profiler are described here.](profiler.md) + + +# About + +CUTLASS is released by NVIDIA Corporation as Open Source software under the +[3-clause "New" BSD license](LICENSE.txt). + +# Contributors + +The official list of CUTLASS developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md). + +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/media/docs/pipeline.md b/media/docs/cpp/pipeline.md similarity index 93% rename from media/docs/pipeline.md rename to media/docs/cpp/pipeline.md index 1a8b551a..aa647304 100644 --- a/media/docs/pipeline.md +++ b/media/docs/cpp/pipeline.md @@ -42,10 +42,10 @@ CUTLASS now includes abstractions for the following features introduced in Hopper. 1. Thread block cluster - level synchronization and query - [APIs](/include/cute/arch/cluster_sm90.hpp) + [APIs](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/cluster_sm90.hpp) 2. Abstractions for new - [barrier instructions](/include/cutlass/arch/barrier.h) + [barrier instructions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/barrier.h) which help with efficient synchronization of threads within a thread block cluster. @@ -54,7 +54,7 @@ for the following features introduced in Hopper. In order to write a performant GEMM Kernel, software pipelining is critical to hide the latency of global memory loads. (Please refer to the -[Efficient GEMM](/media/docs/efficient_gemm.md#pipelining) document.) +[Efficient GEMM](efficient_gemm.md#pipelining) document.) Different threads or groups of threads may have different roles in the pipeline. Some are "producers" that load data or perform computations @@ -73,7 +73,7 @@ dozens of different kinds of asynchronously executing operations that synchronize using multiple barriers organized as a circular list. This complexity is too much for human programmers to manage by hand. As a result, we have developed -[asynchronous Pipeline classes](/include/cutlass/pipeline/). +[asynchronous Pipeline classes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/). These classes help developers orchestrate a pipeline of asynchronous producer and consumer threads, without needing to worry about lower-level hardware details. @@ -173,8 +173,8 @@ and then synchronize among 3 asynchronously executing threads: Please note that this is a basic example. There are different versions possible, depending on what the producer and consumer threads are doing. -Please refer to our [unit tests](/test/unit/pipeline) -and the other [pipeline classes](/include/cutlass/pipeline/pipeline.hpp) +Please refer to our [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/pipeline) +and the other [pipeline classes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/pipeline.hpp) for more details. # Copyright diff --git a/media/docs/profiler.md b/media/docs/cpp/profiler.md similarity index 97% rename from media/docs/profiler.md rename to media/docs/cpp/profiler.md index c4de675c..58088dff 100644 --- a/media/docs/profiler.md +++ b/media/docs/cpp/profiler.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Profiler") - -[README](../../README.md#documentation) > **CUTLASS Profiler** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Profiler") # CUTLASS Profiler @@ -33,9 +31,9 @@ tools/ # Emitting kernels via `emit_kernel_listing.py` -We provide a Python script `emit_kernel_listing.py` that allows a user to selectively test a subset of profiler-based kernels stamped out in `generator.py`. A unique benefit to generate kernels and test via this script is that it can feed a series of runtime arguments, such as different `M`/`N`/`K` and `alpha`/`beta`, to each kernel, instead of relying on a single default value. It also properly generates runtime datatype and cluster shapes for certain kernels to help reduce the generated kernel count and accordingly the total compilation time. An interested user may refer to [emit_kernel_listing.py](../../python/cutlass_library/emit_kernel_listing.py) for details. To enable this new feature, a user should add `-DCUTLASS_BUILD_FOR_PROFILER_REGRESSIONS=ON` when building CUTLASS profiler. +We provide a Python script `emit_kernel_listing.py` that allows a user to selectively test a subset of profiler-based kernels stamped out in `generator.py`. A unique benefit to generate kernels and test via this script is that it can feed a series of runtime arguments, such as different `M`/`N`/`K` and `alpha`/`beta`, to each kernel, instead of relying on a single default value. It also properly generates runtime datatype and cluster shapes for certain kernels to help reduce the generated kernel count and accordingly the total compilation time. An interested user may refer to [emit_kernel_listing.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/emit_kernel_listing.py) for details. To enable this new feature, a user should add `-DCUTLASS_BUILD_FOR_PROFILER_REGRESSIONS=ON` when building CUTLASS profiler. -### Instantiating more kernels with Hopper +## Instantiating more kernels with Hopper With Hopper (SM90), you will need to use an additional flag, `CUTLASS_LIBRARY_INSTANTIATION_LEVEL`, in order to instantiate all possible combinations, which unlike previous architectures, will be in the order of millions of kernels. @@ -81,12 +79,12 @@ Instruction shape levels control the selection of WGMMA shapes used in kernel ge - **Level 2**: Includes shapes that are powers of 2. - **Level 3**: Includes all other shapes. -The detailed defination of the three instantiation levels controlling cluster shape, MMA shape multiplier, and instruction shape can be found in [sm90_shapes.py](../../python/cutlass_library/sm90_shapes.py). +The detailed defination of the three instantiation levels controlling cluster shape, MMA shape multiplier, and instruction shape can be found in [sm90_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py). -Schedule pruning levels decide the epilogue schedule and mainloop schedule to stamp out a kernel instance. As defined in `get_valid_schedules` in [sm90_utils.py](../../python/cutlass_library/sm90_utils.py), +Schedule pruning levels decide the epilogue schedule and mainloop schedule to stamp out a kernel instance. As defined in `get_valid_schedules` in [sm90_utils.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_utils.py), - **Level >= 1**: Indicates that no pruning is being applied. -- **Level 0**: Indicates pruning according to existing [generator.py](../../python/cutlass_library/generator.py) behavior. +- **Level 0**: Indicates pruning according to existing [generator.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/generator.py) behavior. An instantiation level `500`, which is padded to `0500`, thus indicates: @@ -95,7 +93,7 @@ An instantiation level `500`, which is padded to `0500`, thus indicates: - **Cluster Sizes**: At level 5, allowing for clusters with 1, 2, 4, 8, or 16 CTAs. - **Schedule Pruning**: At level 0, where pruning is applied according to the existing `generator.py` behavior. -### Mixed input data type kernels for Hopper +## Mixed input data type kernels for Hopper With Hopper (SM90), the kernel generator will generate the following combinations of mixed input data types ("mixed dtype"): @@ -118,7 +116,7 @@ For each mixed dtype kernel, the kernel generator will generate combinations of For {4-bits-dtype, 8-bits-dtype} x 16-bits-dtype, the kernel generator will further generate kernels using shuffled layouts for the narrow data type matrix, which may have a better performance compared to its non-shuffle counter parts. -### CUTLASS Profiler usage +## CUTLASS Profiler usage The CUTLASS Profiler usage statement may be obtained by executing `cutlass_profiler --help` and appears as follows. ```bash @@ -364,11 +362,11 @@ Profile when execution is performed on device 0 and the C tensor is located on a $ cutlass_profiler --device=0 --allocations=C:1,D:2 --operation=Gemm --m=1024 --n=1024 --k=128 ``` -The format of tensor argument is followed by `:`. The type could be `f32` as 32-bit floating point, `s8` as 8-bit signed integer, etc. The available types can be referred to the `NumericTypeID_enumerants` in [util.cu](tools/library/src/util.cu). The layout could be `row` or `column`. If `--enable_sm90_mixed_dtype_shuffle_test=true` is used, the actual layout of the narrow data type matrix is a shuffled layout, neither `row` nor `column`. +The format of tensor argument is followed by `:`. The type could be `f32` as 32-bit floating point, `s8` as 8-bit signed integer, etc. The available types can be referred to the `NumericTypeID_enumerants` in [util.cu](https://github.com/NVIDIA/cutlass/tree/main/tools/library/src/util.cu). The layout could be `row` or `column`. If `--enable_sm90_mixed_dtype_shuffle_test=true` is used, the actual layout of the narrow data type matrix is a shuffled layout, neither `row` nor `column`. In addition to encoded data types, CUTLASS profiler allows non-encoded generic data types, namely `f8`, `f6`, and `f4`, with corresponding encoding specified through GEMM input argument: `--runtime_input_datatype_a` and `--runtime_input_datatype_b`. Currently, six encoding schemes are supported: `e4m3`, `e5m2`, `e3m2`, `e2m3`, and `e2m1`. -Cluster shapes can be statically set to `Shape;` and specified via runtime arguments: `cluster_m`, `cluster_n` and `cluster_k` in CUTLASS profiler. In addition to preferred cluster shapes, a user can also specify fallback cluster shapes via runtime arguments: `cluster_m_fallback`, `cluster_n_fallback` and `cluster_k_fallback` in CUTLASS profiler. Those fallback cluster shapes are smaller shapes than the preferred ones for the hardware to assign when there is no chance to issue a larger preferred CGA cluster to the GPU. There are several rules for using a flexible CGA: 1) Preferred CGA size should be divisible by fallback CGA size. 2) Grid dim should be divisible by preferred CGA size. 3) Preferred CGA and fallback CGA must have the same depth (cluster_dim.z must be equal). One may refer to our CUTLASS Example [73_blackwell_gemm_flexible_cluster](../../examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for more details of the this feature. +Cluster shapes can be statically set to `Shape;` and specified via runtime arguments: `cluster_m`, `cluster_n` and `cluster_k` in CUTLASS profiler. In addition to preferred cluster shapes, a user can also specify fallback cluster shapes via runtime arguments: `cluster_m_fallback`, `cluster_n_fallback` and `cluster_k_fallback` in CUTLASS profiler. Those fallback cluster shapes are smaller shapes than the preferred ones for the hardware to assign when there is no chance to issue a larger preferred CGA cluster to the GPU. There are several rules for using a flexible CGA: 1) Preferred CGA size should be divisible by fallback CGA size. 2) Grid dim should be divisible by preferred CGA size. 3) Preferred CGA and fallback CGA must have the same depth (cluster_dim.z must be equal). One may refer to our CUTLASS Example [73_blackwell_gemm_flexible_cluster](https://github.com/NVIDIA/cutlass/tree/main/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for more details of the this feature. Please be noted that this feature (flexible cluster shapes within a single grid) is only applicable to `sm100a` kernels. The hardware will rasterize into a single cluster shape for those kernels that do not support this feature even with preferred or fallback cluster shapes assigned. CUTLASS 3.x kernels for Hopper and Blackwell also support a new feature called programatic dependent launch (PDL). This can be enabled with `--use-pdl`, and can overlap the epilogue of the prior kernel with the prologue of the next kernel. This can effectively hide kernel prologues. Using PDL can improve performance for back to back GEMMs. See [dependent kernel launch](dependent_kernel_launch.md) for more information. CUDA graphs can also be used (`--use-cuda-graphs`) with PDL to ensure that smaller kernels are enqueued back-to-back on a stream. diff --git a/media/docs/programming_guidelines.md b/media/docs/cpp/programming_guidelines.md similarity index 99% rename from media/docs/programming_guidelines.md rename to media/docs/cpp/programming_guidelines.md index d7d601a2..b85108d9 100644 --- a/media/docs/programming_guidelines.md +++ b/media/docs/cpp/programming_guidelines.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Programming Guidelines") - -[README](../../README.md#documentation) > **Programming Guidelines** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Programming Guidelines") # Programming Guidelines @@ -954,9 +952,9 @@ For example: ``` Header files such as -[cutlass/cutlass.h](../../include/cutlass/cutlass.h) +[cutlass/cutlass.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cutlass.h) and -[cute/config.hpp](../../include/cutlass/cutlass.h) +[cute/config.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cutlass.h) offer macros for expressing compiler-dependent behavior. These include diff --git a/media/docs/quickstart.md b/media/docs/cpp/quickstart.md similarity index 95% rename from media/docs/quickstart.md rename to media/docs/cpp/quickstart.md index f62d43b5..b728f7ed 100644 --- a/media/docs/quickstart.md +++ b/media/docs/cpp/quickstart.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Quick Start Guide") - -[README](../../README.md#documentation) > **Quick Start** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Quick Start Guide") # Quickstart @@ -217,7 +215,7 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS="50;53" # compiles for NVIDIA Maxwell G ## Using CUTLASS within other applications -Applications should list [`/include`](/include) within their include paths. They must be +Applications should list [`/include`](https://github.com/NVIDIA/cutlass/tree/main/include) within their include paths. They must be compiled as C++17 or greater. **Example:** print the contents of a variable storing half-precision data. @@ -466,7 +464,7 @@ int main(int argc, char const **args) { # CUTLASS Library -The [CUTLASS Library](/tools/library) defines an API for managing and executing collections of compiled +The [CUTLASS Library](https://github.com/NVIDIA/cutlass/tree/main/tools/library) defines an API for managing and executing collections of compiled kernel instances and launching them from host code without template instantiations in client code. The host-side launch API is designed to be analogous to BLAS implementations for convenience, though its @@ -482,16 +480,16 @@ for dense matrix computations on NVIDIA GPUs. The CUTLASS Library is used by the CUTLASS Profiler to manage kernel instances, and it is also used by several SDK examples. -* [10_planar_complex](/examples/10_planar_complex/planar_complex.cu) -* [11_planar_complex_array](/examples/11_planar_complex_array/planar_complex_array.cu) +* [10_planar_complex](https://github.com/NVIDIA/cutlass/tree/main/examples/10_planar_complex/planar_complex.cu) +* [11_planar_complex_array](https://github.com/NVIDIA/cutlass/tree/main/examples/11_planar_complex_array/planar_complex_array.cu) The CUTLASS Library defines enumerated types describing numeric data types, matrix and tensor layouts, math operation classes, complex transformations, and more. -Client applications should specify [`tools/library/include`](/tools/library/include) in their +Client applications should specify [`tools/library/include`](https://github.com/NVIDIA/cutlass/tree/main/tools/library/include) in their include paths and link against libcutlas_lib.so. -The CUTLASS SDK example [10_planar_complex](/examples/10_planar_complex/CMakeLists.txt) specifies +The CUTLASS SDK example [10_planar_complex](https://github.com/NVIDIA/cutlass/tree/main/examples/10_planar_complex/CMakeLists.txt) specifies its dependency on the CUTLASS Library with the following CMake command. ``` target_link_libraries( @@ -662,7 +660,7 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=tensorop*s* ## Instantiating a Blackwell SM100 GEMM kernel Blackwell SM100 kernels are instantiated very similarly to Hopper kernels. Let us start with an -[FP8 GEMM without blockscaling](../../test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu) +[FP8 GEMM without blockscaling](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu) as an example. The kernel starts with setting up datatypes and cluster shapes. @@ -706,7 +704,7 @@ for Blackwell, so the epilogue fusion is built in a same way as an SM90 epilogue ``` One can refer to our Sm100 unit tests as examples of how to correctly -choose mainloop schedules. All of our dispatch policies can be found in [dispatch_policy.hpp](../../include/cutlass/gemm/dispatch_policy.hpp) +choose mainloop schedules. All of our dispatch policies can be found in [dispatch_policy.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and more comprehensive Blackwell specific documentation for valid dispatch policies can be in [blackwell_functionality.md](./blackwell_functionality.md). @@ -729,7 +727,7 @@ dispatch policies can be in [blackwell_functionality.md](./blackwell_functionali >; ``` -Instantiating a blockscaled GEMM kernel is slightly different. Referring to an [MXFP8 GEMM](./../../test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu) sample unit test, it takes a different tensor operation class: +Instantiating a blockscaled GEMM kernel is slightly different. Referring to an [MXFP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu) sample unit test, it takes a different tensor operation class: ```c++ using ElementA = cutlass::mx_float8_t; diff --git a/media/docs/terminology.md b/media/docs/cpp/terminology.md similarity index 94% rename from media/docs/terminology.md rename to media/docs/cpp/terminology.md index f4e3a9d7..1c5d31ea 100644 --- a/media/docs/terminology.md +++ b/media/docs/cpp/terminology.md @@ -1,13 +1,11 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Terminology") - -[README](../../README.md#documentation) > **Terminology** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Terminology") # CUTLASS Terminology **cute::Layout**: A `cute::Layout` vocabulary type composed of the hierarchical `cute::Shape` and `cute::Stride` -tuples that is used throughout CUTLASS 3.0 to represent and manipulate thread and data layouts. More details are included in the [CuTe specific tensor type documentation](/media/docs/cute/03_tensor.md). +tuples that is used throughout CUTLASS 3.0 to represent and manipulate thread and data layouts. More details are included in the [CuTe specific tensor type documentation](cute/03_tensor.md). -**cute::Tensor**: A pointer backed by a `cute::Layout` used to represent a tensor. More details are included in the [CuTe specific tensor type documentation](/media/docs/cute/03_tensor.md). +**cute::Tensor**: A pointer backed by a `cute::Layout` used to represent a tensor. More details are included in the [CuTe specific tensor type documentation](cute/03_tensor.md). **Capacity**: (scalar) physical number of elements in memory required to store a multidimensional object; expressed as the type's LongIndex type - example: the capacity of a column-major matrix is `lda * N` @@ -71,7 +69,7 @@ contiguous and strided dimensions of a tile. `sizeof(Array)` - gives expected value in units of bytes with minimum storage of `1 B`: (sizeof_bits::value * N) / 8 **Operator**: an object performing a computation on matrix or tensor objects. May be further refined by scope within the execution model hierarchy. Deprecated starting CUTLASS 3.0, -replaced by [MMA and Copy atoms from CuTe](/media/docs/cute/0t_mma_atom.md). +replaced by [MMA and Copy atoms from CuTe](cute/0t_mma_atom.md). **Tile Iterator**: abstraction for accessing and traversing a sequence of tiles in a tensor; CUTLASS specifies [formal concepts for tile iterators](tile_iterator_concept.md). Deprecated starting CUTLASS 3.0. diff --git a/media/docs/tile_iterator_concept.md b/media/docs/cpp/tile_iterator_concept.md similarity index 99% rename from media/docs/tile_iterator_concept.md rename to media/docs/cpp/tile_iterator_concept.md index f8db020d..63a3eb0b 100644 --- a/media/docs/tile_iterator_concept.md +++ b/media/docs/cpp/tile_iterator_concept.md @@ -1,6 +1,4 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Tile Iterator Concepts") - -[README](../../README.md#documentation) > **Tile Iterator Concepts** +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Tile Iterator Concepts") # Tile Iterator Concepts diff --git a/media/docs/utilities.md b/media/docs/cpp/utilities.md similarity index 95% rename from media/docs/utilities.md rename to media/docs/cpp/utilities.md index e8e1b98e..b6dffe05 100644 --- a/media/docs/utilities.md +++ b/media/docs/cpp/utilities.md @@ -1,13 +1,12 @@ -![ALT](../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Code Organization") +![ALT](../../images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Code Organization") -[README](../../README.md#documentation) > **CUTLASS Utilities** Note: This document discusses utilities commonly used with code that targets CUTLASS 2.x. Although CUTLASS 3.0's primary entry point APIs do not transact in these `cutlass::*` tensor types anymore, users can still find them convenient for managing allocations with trivial affine layouts. -For more advanced host side tensor management, [`cute::Tensor`](/media/docs/cute/03_tensor.md)s +For more advanced host side tensor management, [`cute::Tensor`](cute/03_tensor.md)s can be used on either host or device for any memory space and full expressive power of -[`cute::Layout`](/media/docs/cute/01_layout.md)s. +[`cute::Layout`](cute/01_layout.md)s. # CUTLASS Utilities @@ -17,12 +16,12 @@ flexible implementations of needed functionality, but they are not expected to b Applications should configure their builds to list `/tools/util/include` in their include paths. -Source code is in [`/tools/util/include/cutlass/util/`](/tools/util/include/cutlass/util). +Source code is in [`/tools/util/include/cutlass/util/`](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util). ## Tensor Allocation and I/O To allocate a tensor with storage in both host and device memory, use `HostTensor` in -[`cutlass/util/host_tensor.h`](/tools/util/include/cutlass/util/host_tensor.h) +[`cutlass/util/host_tensor.h`](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/host_tensor.h) ```c++ template @@ -61,7 +60,7 @@ cutlass::TensorView device_view = tensor.de ``` Printing to human-readable CSV output is accoplished with `std::ostream::operator<<()` defined in -[`cutlass/util/tensor_view_io.h`](/tools/util/include/cutlass/util/tensor_view_io.h). +[`cutlass/util/tensor_view_io.h`](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/tensor_view_io.h). Note, this assumes all views refer to host memory. ```c++ #include @@ -428,7 +427,7 @@ synclog at [synclog_at]: [header] line=[line] thread=[threadIdx.xyz] block=[bloc * `header`: Name of the synchronization event. * `line`: Code line number of the synchronization operation calling into `synclog`. -Additional information may appear at the end of each line, such as shared memory address, phase bit, and arrive count. For more detailed information on `synclog` output, refer to [synclog.hpp](../../include/cutlass/arch/synclog.hpp) in the CUTLASS source code. +Additional information may appear at the end of each line, such as shared memory address, phase bit, and arrive count. For more detailed information on `synclog` output, refer to [synclog.hpp](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp) in the CUTLASS source code. Please note that `synclog` is an experimental feature, and its functionality is not always guaranteed. We encourage its use in custom kernels and CUTLASS examples, though it is known to be incompatible with profiler kernels. diff --git a/pyproject.toml b/pyproject.toml index f892fe9d..04571be9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "3.8.0.0" +version = "3.9.0.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index d60e2846..cd4b4617 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -134,7 +134,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '3.8.0' +this.__version__ = '3.9.0' from cutlass.backend import create_memory_pool from cutlass.emit.pytorch import pytorch diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 5a954586..a6eca001 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -282,6 +282,8 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0): def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode ): profiler_reference_computing = "--verification-providers=device --providers=cutlass" + + # beta values for L0 and L1 # TODO: randomize beta values for wider coverage beta_values = [0.5] diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index ef9ed167..4ae65ed0 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -10025,7 +10025,8 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio tile_sizes_cooperative = [ [128, 128, 128], - [128, 128, 256] + [128, 128, 256], + [256, 128, 128] ] tile_sizes_pingpong = [ diff --git a/python/setup_library.py b/python/setup_library.py index 003111f3..d5f74b9a 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='cutlass_library', - version='3.8.0', + version='3.9.0', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 822dfe16..31f92295 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='3.8.0', + version='3.9.0', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 4d10b025..09b28f00 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -140,7 +140,7 @@ function(cutlass_test_unit_add_executable_split_file NAME) if (CUTLASS_UNIT_TEST_SPLIT_FILES) execute_process( WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} - COMMAND ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/tools/scripts/split_test_cmake.py + COMMAND ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/tools/util/scripts/split_test_cmake.py ${NAME} ${CMAKE_CURRENT_SOURCE_DIR} --src_files ${SUBARGV} diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu index b192ec72..bcfcb12d 100644 --- a/test/unit/cute/ampere/cooperative_gemm.cu +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -499,3 +499,40 @@ TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA_Predicated_Reg) { test_cooperative_gemm_col_major_layout_rmem_c(shape_mnk, tiled_mma); } + +TEST(SM80_CuTe_Ampere, CooperativeGemmLDSMx2) { + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 128; + using TA = cute::half_t; + using TB = cute::half_t; + using TC = float; + + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _2, _0>>, + Tile<_32, _16, _16> + >{}; + + auto global_a_layout = make_layout(Shape<_32, _32>{}, LayoutRight{}); + auto global_b_layout = make_layout(Shape<_16, _32>{}, LayoutRight{}); + auto global_c_layout = make_layout(Shape<_32, _16>{}, LayoutRight{}); + + test_cooperative_gemm + (global_a_layout, + global_b_layout, + global_c_layout, + global_a_layout, + global_b_layout, + global_c_layout, + tiled_mma, + identity{}, + identity{}, + identity{}, + identity{}, + SM75_U32x4_LDSM_N{}, + SM75_U32x2_LDSM_N{}); +} diff --git a/test/unit/cute/cooperative_gemm_common.hpp b/test/unit/cute/cooperative_gemm_common.hpp index e524dc28..3ff20d40 100644 --- a/test/unit/cute/cooperative_gemm_common.hpp +++ b/test/unit/cute/cooperative_gemm_common.hpp @@ -188,7 +188,8 @@ template + class SMemCopyLdOpC, + class SMemCopyStOpC> __launch_bounds__(ThreadBlockSize) __global__ void cooperative_gemm_kernel(GMemALayout gmem_a_layout, GMemBLayout gmem_b_layout, @@ -209,7 +210,8 @@ cooperative_gemm_kernel(GMemALayout gmem_a_layout, CStoreTransform c_store_transform, SMemCopyOpA a_copy_op, SMemCopyOpB b_copy_op, - SMemCopyOpC c_copy_op) + SMemCopyLdOpC c_copy_ld_op, + SMemCopyStOpC c_copy_st_op) { using namespace cute; @@ -242,7 +244,7 @@ cooperative_gemm_kernel(GMemALayout gmem_a_layout, threadIdx.x, tiled_mma, alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, a_load_transform, b_load_transform, c_load_transform, c_store_transform, - a_copy_op, b_copy_op, c_copy_op + a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op ); __syncthreads(); @@ -366,7 +368,8 @@ template, class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, - class CSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> + class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment> void test_cooperative_gemm(GMemALayout gmem_a_layout, GMemBLayout gmem_b_layout, GMemCLayout gmem_c_layout, @@ -380,7 +383,8 @@ void test_cooperative_gemm(GMemALayout gmem_a_layout, CStoreTransform c_store_transform = {}, ASMemCopyOp a_smem_copy_op = {}, BSMemCopyOp b_smem_copy_op = {}, - CSMemCopyOp c_smem_copy_op = {}) + CSMemCopyLdOp c_smem_copy_ld_op = {}, + CSMemCopyStOp c_smem_copy_st_op = {}) { static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); @@ -428,7 +432,7 @@ void test_cooperative_gemm(GMemALayout gmem_a_layout, TA, TB, TC, decltype(alpha), decltype(beta), TiledMma, ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, - ASMemCopyOp, BSMemCopyOp, CSMemCopyOp + ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp >; ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); @@ -453,7 +457,8 @@ void test_cooperative_gemm(GMemALayout gmem_a_layout, c_store_transform, a_smem_copy_op, b_smem_copy_op, - c_smem_copy_op + c_smem_copy_ld_op, + c_smem_copy_st_op ); cudaError_t result = cudaDeviceSynchronize(); diff --git a/test/unit/cute/hopper/cooperative_gemm.cu b/test/unit/cute/hopper/cooperative_gemm.cu index c6ed2eb2..bb71e879 100644 --- a/test/unit/cute/hopper/cooperative_gemm.cu +++ b/test/unit/cute/hopper/cooperative_gemm.cu @@ -115,3 +115,46 @@ TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF16) { } #endif + +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + +TEST(SM90_CuTe_Hopper, CooperativeGemmSTSM) { + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 128; + using TA = cute::half_t; + using TB = cute::half_t; + using TC = cute::half_t; + + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _2, _0>>, + Tile<_32, _32, _16> + >{}; + + auto global_a_layout = make_layout(Shape<_64, _64>{}, LayoutRight{}); + auto global_b_layout = make_layout(Shape<_64, _64>{}, LayoutRight{}); + auto global_c_layout = make_layout(Shape<_64, _64>{}, LayoutRight{}); + + test_cooperative_gemm + (global_a_layout, + global_b_layout, + global_c_layout, + global_a_layout, + global_b_layout, + global_c_layout, + tiled_mma, + identity{}, + identity{}, + identity{}, + identity{}, + SM75_U32x4_LDSM_N{}, + SM75_U32x4_LDSM_N{}, + SM75_U32x4_LDSM_N{}, + SM90_U32x4_STSM_N{}); +} + +#endif diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index d8482444..b601054f 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -2449,17 +2449,20 @@ struct HostCollectiveEpilogue { // example of how to set kernel activation arguments // see ActivationFunctor::Arguments in activation.h for definition // if Arguments doesn't exist then fusion_args.activation is empty + auto init_activation_args = [] (auto activation, auto& args) { + using Activation = cute::remove_cvref_t; + if constexpr (cute::is_same_v>) { + args.lower_bound = 0; // Treat Clamp as ReLU + args.upper_bound = cutlass::platform::identity_for_minimum(); + } + if constexpr (cute::is_same_v>) { + args.scale = ElementCompute(1); + } + }; - if constexpr (cute::is_same_v>) { - fusion_args.activation.scale = ElementCompute(1); + if constexpr (not cute::is_same_v>) { + init_activation_args(ActivationFunctor{}, fusion_args.activation); } - - // Treat Clamp as ReLU - if constexpr (cute::is_same_v>) { - fusion_args.activation.lower_bound = 0; - fusion_args.activation.upper_bound = std::numeric_limits::max(); - } - if constexpr (IsAbsMaxEnabledD) { fusion_args.amax_D_ptr = abs_max_D.device_data(); } diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt index 380b4aa4..7ec9e3ff 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/CMakeLists.txt @@ -45,24 +45,26 @@ cutlass_test_unit_gemm_device_add_executable( BATCH_SOURCES ON BATCH_SIZE 1 - sm120_bs_gemm_f4_f4_f32_f32_epilogue_fusion.cu - sm120_bs_gemm_f4_f4_f32_f4_epilogue_fusion.cu - sm120_bs_gemm_f4_f4_f32_bf16_epilogue_fusion.cu + sm120_bs_gemm_nvf4_nvf4_f32_f32_epilogue_fusion.cu + sm120_bs_gemm_nvf4_nvf4_f32_nvf4_epilogue_fusion.cu + sm120_bs_gemm_nvf4_nvf4_f32_bf16_epilogue_fusion.cu ) cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_bs_gemm_device_tensorop_sm120 - sm120_bs_gemm_f4_f4_f32_bf16.cu - sm120_bs_gemm_f4_f4_f32_f16.cu - sm120_bs_gemm_f4_f4_f32_f32.cu - sm120_bs_gemm_f4_f4_f32_f32_narrow_output.cu - sm120_bs_gemm_f4_f4_f32_epilogue.cu + sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu + sm120_bs_gemm_nvf4_nvf4_f32_f16.cu + sm120_bs_gemm_nvf4_nvf4_f32_f32.cu + sm120_bs_gemm_nvf4_nvf4_f32_f32_narrow_output.cu + sm120_bs_gemm_nvf4_nvf4_f32_epilogue.cu + sm120_bs_gemm_mxf4_mxf4_f32_f32.cu + sm120_bs_gemm_mxf6_mxf8_f32_f32.cu ) cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_bs_gemm_device_tensorop_sm120_stream_k - sm120_bs_gemm_f4_f4_f32_f32_stream_k.cu + sm120_bs_gemm_nvf4_nvf4_f32_f32_stream_k.cu ) endif() diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_epilogue.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_epilogue.cu deleted file mode 100644 index b2f5b878..00000000 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_epilogue.cu +++ /dev/null @@ -1,590 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "../../../common/cutlass_unit_test.h" - -#include "../gemm_testbed_3x.hpp" - -#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) - -using namespace cute; - -/////////////////////////////////////////////////////////////////////////////// - -namespace kernel_1 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = float; - using ElementD = cutlass::float_e2m1_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue4m3_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::RowMajor; - using LayoutSFD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::nv_float4_t; - using ElementPairB = cutlass::nv_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 16; - using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, - LayoutSFD - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - - -} // kernel_1 - -namespace kernel_2 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = float; - using ElementD = cutlass::float_e2m1_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue8m0_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::RowMajor; - using LayoutSFD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::mx_float4_t; - using ElementPairB = cutlass::mx_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 32; - using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, - LayoutSFD - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - -} // kernel_2 - -namespace kernel_3 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::half_t; - using ElementD = cutlass::float_e2m1_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue4m3_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::RowMajor; - using LayoutSFD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::nv_float4_t; - using ElementPairB = cutlass::nv_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 16; - using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, - LayoutSFD - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - - -} // kernel_3 - -namespace kernel_4 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::half_t; - using ElementD = cutlass::float_e2m1_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue8m0_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::RowMajor; - using LayoutSFD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::mx_float4_t; - using ElementPairB = cutlass::mx_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 32; - using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, - LayoutSFD - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - -} // kernel_4 - -namespace kernel_5 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::bfloat16_t; - using ElementD = cutlass::float_e2m1_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue4m3_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::RowMajor; - using LayoutSFD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::nv_float4_t; - using ElementPairB = cutlass::nv_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 16; - using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, - LayoutSFD - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - - -} // kernel_5 - -namespace kernel_6 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::bfloat16_t; - using ElementD = cutlass::float_e2m1_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue8m0_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::RowMajor; - using LayoutSFD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::mx_float4_t; - using ElementPairB = cutlass::mx_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 32; - using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, - LayoutSFD - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - -} // kernel_6 - -namespace kernel_7 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = void; - using ElementD = cutlass::float_e2m1_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue4m3_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::RowMajor; - using LayoutSFD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::nv_float4_t; - using ElementPairB = cutlass::nv_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 16; - using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, - LayoutSFD, - ElementC - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - - -} // kernel_7 - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_epilogue_vs16, 128x128x256) { - bool result = test::gemm::device::TestSmall(1.0, 0.5); - EXPECT_TRUE(result); -} - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs32_tensor_op_f32_f32_epilogue_vs32, 128x128x256) { - bool result = test::gemm::device::TestSmall(1.0, 0.5); - EXPECT_TRUE(result); -} - -// ==== mixed datatypes for C (fp16/bf16) / D (fp32) matrices ==== // -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f16_f32_epilogue_vs16, 128x128x256) { - bool result = test::gemm::device::TestSmall(1.0, 0.5); - EXPECT_TRUE(result); -} - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs32_tensor_op_f16_f32_epilogue_vs32, 128x128x256) { - bool result = test::gemm::device::TestSmall(1.0, 0.5); - EXPECT_TRUE(result); -} - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_bf16_f32_epilogue_vs16, 128x128x256) { - bool result = test::gemm::device::TestSmall(1.0, 0.5); - EXPECT_TRUE(result); -} - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs32_tensor_op_bf16_f32_epilogue_vs32, 128x128x256) { - bool result = test::gemm::device::TestSmall(1.0, 0.5); - EXPECT_TRUE(result); -} - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs32_tensor_op_void_f32_epilogue_vs32, 128x128x256) { - bool result = test::gemm::device::TestSmallFusion(1.0, 0.0); - EXPECT_TRUE(result); -} - -#endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu similarity index 96% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu index 02421673..60e9f4d0 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu @@ -87,7 +87,7 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -97,7 +97,7 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -114,7 +114,8 @@ namespace kernel_1 { } // kernel_1 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs32_tensor_op_f32_static_sched, 128x128x256) { + +TEST(SM120_Device_Blockscaled_Gemm_mxf4t_mxf4n_f32n_tensor_op_f32, 128x128x256) { bool result = test::gemm::device::TestSmall(1.0, 0.5); EXPECT_TRUE(result); } diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf6_mxf8_f32_f32.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf6_mxf8_f32_f32.cu new file mode 100644 index 00000000..64b79cad --- /dev/null +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf6_mxf8_f32_f32.cu @@ -0,0 +1,123 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +namespace kernel_1 { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + + static constexpr int AlignmentA = 64 * 8 / cutlass::sizeof_bits::value; // Align to 64 bytes. + static constexpr int AlignmentB = 96 * 8 / cutlass::sizeof_bits::value; // Align to 96 bytes. + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using TileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementPairA, LayoutA, AlignmentA, + ElementPairB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + template + struct dummy { + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + }; + using GemmKernel = typename dummy::GemmKernel; + using Gemm = typename dummy::Gemm; + +} // kernel_1 + + +TEST(SM120_Device_Blockscaled_Gemm_mxf6t_mxf8n_f32n_tensor_op_f32, 128x128x128) { + bool result = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(result); +} + +#endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_bf16.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu similarity index 64% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_bf16.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu index e64d6483..782eec30 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_bf16.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu @@ -87,7 +87,7 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -97,7 +97,7 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -115,74 +115,10 @@ namespace kernel_1 { } // kernel_1 -namespace kernel_3 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::bfloat16_t; - using ElementD = cutlass::bfloat16_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue8m0_t; - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::ColumnMajor; - - using ElementPairA = cutlass::mx_float4_t; - using ElementPairB = cutlass::mx_float4_t; - - static constexpr int AlignmentA = 64 * 8 / cutlass::sizeof_bits::value; // Align to 64 bytes. - static constexpr int AlignmentB = 64 * 8 / cutlass::sizeof_bits::value; // Align to 64 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_128>; - using ClusterShape = Shape<_1,_1,_1>; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120 - >::CollectiveOp; - - template - struct dummy { - using TileSchedulerTag = cutlass::gemm::PersistentScheduler; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - TileSchedulerTag>; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - -} // kernel_3 - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_bf16, 128x128x256) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_bf16, 128x128x256) { bool result = test::gemm::device::TestSmall(1.0, 0.5); EXPECT_TRUE(result); } -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs32_tensor_op_bf16, 128x128x128) { - bool result = test::gemm::device::TestSmall(1.0, 0.5); - EXPECT_TRUE(result); -} - #endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_bf16_epilogue_fusion.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16_epilogue_fusion.cu similarity index 93% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_bf16_epilogue_fusion.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16_epilogue_fusion.cu index ef7a1e5f..5ae10746 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_bf16_epilogue_fusion.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16_epilogue_fusion.cu @@ -97,7 +97,7 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -108,7 +108,7 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -167,7 +167,7 @@ namespace kernel_2 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -178,7 +178,7 @@ namespace kernel_2 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -237,7 +237,7 @@ namespace kernel_3 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -248,7 +248,7 @@ namespace kernel_3 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -309,7 +309,7 @@ namespace kernel_4 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -320,7 +320,7 @@ namespace kernel_4 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -343,7 +343,7 @@ namespace kernel_4 { // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_bf16n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_relu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_bf16n_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_relu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -354,7 +354,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_bf16n_vs16_tensor_op_f32_f32_ep // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_bf16n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_gelu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_bf16n_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_gelu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -365,7 +365,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_bf16n_vs16_tensor_op_f32_f32_ep // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_bf16n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_relu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_bf16n_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_relu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -377,7 +377,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_bf16n_vs16_tensor_op_f32_f32_ep // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_bf16n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_gelu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_bf16n_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_gelu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f16.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_epilogue.cu similarity index 59% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f16.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_epilogue.cu index ad0d8a2e..a88bc06b 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f16.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_epilogue.cu @@ -39,6 +39,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" @@ -58,8 +59,8 @@ using namespace cute; namespace kernel_1 { using ElementA = cutlass::float_e2m1_t; using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::half_t; - using ElementD = cutlass::half_t; + using ElementC = float; + using ElementD = cutlass::float_e2m1_t; using ElementAccumulator = float; using ElementCompute = float; using ElementSF = cutlass::float_ue4m3_t; @@ -67,7 +68,8 @@ namespace kernel_1 { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::RowMajor; + using LayoutSFD = cutlass::layout::RowMajor; using ElementPairA = cutlass::nv_float4_t; using ElementPairB = cutlass::nv_float4_t; @@ -80,6 +82,15 @@ namespace kernel_1 { using TileShape = Shape<_128,_128,_256>; using ClusterShape = Shape<_1,_1,_1>; + constexpr int SFVectorSize = 16; + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + SFVectorSize, + ElementD, + ElementCompute, + ElementSF, + LayoutSFD + >; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, @@ -87,7 +98,8 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -97,40 +109,41 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template struct dummy { - using TileSchedulerTag = cutlass::gemm::PersistentScheduler; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, CollectiveEpilogue, - TileSchedulerTag>; + cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; using GemmKernel = typename dummy::GemmKernel; using Gemm = typename dummy::Gemm; + } // kernel_1 namespace kernel_2 { using ElementA = cutlass::float_e2m1_t; using ElementB = cutlass::float_e2m1_t; using ElementC = cutlass::half_t; - using ElementD = cutlass::half_t; + using ElementD = cutlass::float_e2m1_t; using ElementAccumulator = float; using ElementCompute = float; - using ElementSF = cutlass::float_ue8m0_t; + using ElementSF = cutlass::float_ue4m3_t; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::RowMajor; + using LayoutSFD = cutlass::layout::RowMajor; - using ElementPairA = cutlass::mx_float4_t; - using ElementPairB = cutlass::mx_float4_t; + using ElementPairA = cutlass::nv_float4_t; + using ElementPairB = cutlass::nv_float4_t; static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. @@ -140,6 +153,15 @@ namespace kernel_2 { using TileShape = Shape<_128,_128,_256>; using ClusterShape = Shape<_1,_1,_1>; + constexpr int SFVectorSize = 16; + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + SFVectorSize, + ElementD, + ElementCompute, + ElementSF, + LayoutSFD + >; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, @@ -147,7 +169,8 @@ namespace kernel_2 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -157,97 +180,189 @@ namespace kernel_2 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - -} // kernel_2 - - -namespace kernel_3 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::half_t; - using ElementD = cutlass::half_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue8m0_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::ColumnMajor; - - using ElementPairA = cutlass::mx_float4_t; - using ElementPairB = cutlass::mx_float4_t; - - static constexpr int AlignmentA = 64 * 8 / cutlass::sizeof_bits::value; // Align to 64 bytes. - static constexpr int AlignmentB = 64 * 8 / cutlass::sizeof_bits::value; // Align to 64 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_128>; - using ClusterShape = Shape<_1,_1,_1>; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120 - >::CollectiveOp; - - template - struct dummy { - using TileSchedulerTag = cutlass::gemm::PersistentScheduler; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, CollectiveEpilogue, - TileSchedulerTag>; + cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; using GemmKernel = typename dummy::GemmKernel; using Gemm = typename dummy::Gemm; + +} // kernel_2 + +namespace kernel_3 { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::bfloat16_t; + using ElementD = cutlass::float_e2m1_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue4m3_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::RowMajor; + using LayoutSFD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::nv_float4_t; + using ElementPairB = cutlass::nv_float4_t; + + static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using TileShape = Shape<_128,_128,_256>; + using ClusterShape = Shape<_1,_1,_1>; + + constexpr int SFVectorSize = 16; + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + SFVectorSize, + ElementD, + ElementCompute, + ElementSF, + LayoutSFD + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementPairA, LayoutA, AlignmentA, + ElementPairB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + template + struct dummy { + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + }; + using GemmKernel = typename dummy::GemmKernel; + using Gemm = typename dummy::Gemm; + + } // kernel_3 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f16, 128x128x256) { +namespace kernel_4 { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::float_e2m1_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue4m3_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::RowMajor; + using LayoutSFD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::nv_float4_t; + using ElementPairB = cutlass::nv_float4_t; + + static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using TileShape = Shape<_128,_128,_256>; + using ClusterShape = Shape<_1,_1,_1>; + + constexpr int SFVectorSize = 16; + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + SFVectorSize, + ElementD, + ElementCompute, + ElementSF, + LayoutSFD, + ElementC + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementPairA, LayoutA, AlignmentA, + ElementPairB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + template + struct dummy { + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + }; + using GemmKernel = typename dummy::GemmKernel; + using Gemm = typename dummy::Gemm; + + +} // kernel_4 + +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f32_f32_epilogue_vs16, 128x128x256) { bool result = test::gemm::device::TestSmall(1.0, 0.5); EXPECT_TRUE(result); } -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs32_tensor_op_f16_static_sched, 128x128x256) { + +// ==== mixed datatypes for C (fp16/bf16) / D (fp32) matrices ==== // +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f16_f32_epilogue_vs16, 128x128x256) { bool result = test::gemm::device::TestSmall(1.0, 0.5); EXPECT_TRUE(result); } -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs32_tensor_op_f16, 128x128x128) { + +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_bf16_f32_epilogue_vs16, 128x128x256) { bool result = test::gemm::device::TestSmall(1.0, 0.5); EXPECT_TRUE(result); } + +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_void_f32_epilogue_vs32, 128x128x256) { + bool result = test::gemm::device::TestSmallFusion(1.0, 0.0); + EXPECT_TRUE(result); +} + #endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f16.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f16.cu new file mode 100644 index 00000000..22fcf89a --- /dev/null +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f16.cu @@ -0,0 +1,126 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +namespace kernel_1 { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue4m3_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementPairA = cutlass::nv_float4_t; + using ElementPairB = cutlass::nv_float4_t; + + static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using TileShape = Shape<_128,_128,_256>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementPairA, LayoutA, AlignmentA, + ElementPairB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + template + struct dummy { + using TileSchedulerTag = cutlass::gemm::PersistentScheduler; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileSchedulerTag>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + }; + using GemmKernel = typename dummy::GemmKernel; + using Gemm = typename dummy::Gemm; + +} // kernel_1 + + + +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f16, 128x128x256) { + bool result = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(result); +} + + +#endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32.cu new file mode 100644 index 00000000..51540341 --- /dev/null +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32.cu @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + + +namespace kernel_1 { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue4m3_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementPairA = cutlass::nv_float4_t; + using ElementPairB = cutlass::nv_float4_t; + + static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using TileShape = Shape<_128,_128,_256>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, + ElementPairA, LayoutA, AlignmentA, + ElementPairB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + template + struct dummy { + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + }; + using GemmKernel = typename dummy::GemmKernel; + using Gemm = typename dummy::Gemm; + +} // kernel_1 + + + +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32, 128x128x256) { + bool result = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(result); +} + +#endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_epilogue_fusion.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_epilogue_fusion.cu similarity index 93% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_epilogue_fusion.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_epilogue_fusion.cu index 16c69bbf..0e76bf69 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_epilogue_fusion.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_epilogue_fusion.cu @@ -97,7 +97,7 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -108,7 +108,7 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -167,7 +167,7 @@ namespace kernel_2 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -178,7 +178,7 @@ namespace kernel_2 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -237,7 +237,7 @@ namespace kernel_3 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -248,7 +248,7 @@ namespace kernel_3 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -307,7 +307,7 @@ namespace kernel_4 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -318,7 +318,7 @@ namespace kernel_4 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -377,7 +377,7 @@ namespace kernel_5 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -388,7 +388,7 @@ namespace kernel_5 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -447,7 +447,7 @@ namespace kernel_6 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -458,7 +458,7 @@ namespace kernel_6 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -481,7 +481,7 @@ namespace kernel_6 { // Acc: fp32 // Scale (alpha, beta): fp32 // D: fp32 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_per_row_bias_relu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_f32_epilogue, 128x128x256_per_row_bias_relu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -492,7 +492,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epi // Acc: fp32 // Scale (alpha, beta): fp32 // D: fp32 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_per_row_bias_gelu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_f32_epilogue, 128x128x256_per_row_bias_gelu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -503,7 +503,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epi // Acc: fp32 // Scale (alpha, beta): fp32 // D: fp32 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_gelu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_gelu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -514,7 +514,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epi // Acc: fp32 // Scale (alpha, beta): fp32 // D: fp32 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_relu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_relu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -525,7 +525,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epi // Acc: fp32 // Scale (alpha, beta): fp32 // D: fp32 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_per_row_bias_clamp) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_f32_epilogue, 128x128x256_per_row_bias_clamp) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -536,7 +536,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epi // Acc: fp32 // Scale (alpha, beta): fp32 // D: fp32 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_clamp) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_clamp) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_narrow_output.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_narrow_output.cu similarity index 62% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_narrow_output.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_narrow_output.cu index fd584cee..069b66eb 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_narrow_output.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_narrow_output.cu @@ -102,7 +102,7 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -113,7 +113,7 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -130,89 +130,9 @@ namespace kernel_1 { } // kernel_1 -namespace kernel_2 { - using ElementA = cutlass::float_e2m1_t; - using ElementB = cutlass::float_e2m1_t; - using ElementC = cutlass::bfloat16_t; - using ElementD = cutlass::float_e2m3_t; - - using ElementAccumulator = float; - using ElementCompute = float; - using ElementSF = cutlass::float_ue8m0_t; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; - using LayoutD = cutlass::layout::RowMajor; - - using ElementPairA = cutlass::mx_float4_t; - using ElementPairB = cutlass::mx_float4_t; - - static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits::value; // Align to 16 bytes. - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using TileShape = Shape<_128,_128,_256>; - using ClusterShape = Shape<_1,_1,_1>; - - constexpr int SFVectorSize = 32; - using LayoutSFD = cutlass::layout::RowMajor; - using ElementBias = cutlass::bfloat16_t; - using GmemLayoutSFC = cutlass::layout::RowMajor; - - using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltActBlockScaleFactor< - cutlass::epilogue::thread::ReLU, - SFVectorSize, - ElementD, - ElementCompute, - ElementSF, LayoutSFD, - ElementBias, - ElementC>; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative - ,FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp, - ElementPairA, LayoutA, AlignmentA, - ElementPairB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120 - >::CollectiveOp; - - template - struct dummy { - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; // both void (default) and PersistentScheduler map to dynamic scheduler with CLC query - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - }; - using GemmKernel = typename dummy::GemmKernel; - using Gemm = typename dummy::Gemm; - -} // kernel_2 - -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_fe2m3n, 128x128x256) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_fe2m3n, 128x128x256) { bool result = test::gemm::device::TestSmall(1.0, 0.5); EXPECT_TRUE(result); } -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs32_tensor_op_f32_fe2m3n, 128x128x256) { - bool result = test::gemm::device::TestSmallFusion(1.0, 0); - EXPECT_TRUE(result); -} - #endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_stream_k.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_stream_k.cu similarity index 96% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_stream_k.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_stream_k.cu index 7f1132d2..caff7ff9 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f32_stream_k.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_f32_stream_k.cu @@ -87,7 +87,7 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -97,7 +97,7 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -115,7 +115,7 @@ namespace kernel_1 { } // kernel_1 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_f32n_vs16_tensor_op_f32_stream_k, 128x128x256) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_f32n_tensor_op_f32_stream_k, 128x128x256) { bool result = test::gemm::device::TestSmall(1.0, 0.5); EXPECT_TRUE(result); } diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f4_epilogue_fusion.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_epilogue_fusion.cu similarity index 93% rename from test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f4_epilogue_fusion.cu rename to test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_epilogue_fusion.cu index 1c864f3e..818b6b74 100644 --- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_f4_f4_f32_f4_epilogue_fusion.cu +++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_epilogue_fusion.cu @@ -104,7 +104,7 @@ namespace kernel_1 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -115,7 +115,7 @@ namespace kernel_1 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -182,7 +182,7 @@ namespace kernel_2 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -193,7 +193,7 @@ namespace kernel_2 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -260,7 +260,7 @@ namespace kernel_3 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -271,7 +271,7 @@ namespace kernel_3 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -336,7 +336,7 @@ namespace kernel_4 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -347,7 +347,7 @@ namespace kernel_4 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -413,7 +413,7 @@ namespace kernel_5 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -424,7 +424,7 @@ namespace kernel_5 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -490,7 +490,7 @@ namespace kernel_6 { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOperation >::CollectiveOp; @@ -501,7 +501,7 @@ namespace kernel_6 { ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120 + cutlass::gemm::KernelTmaWarpSpecializedCooperative >::CollectiveOp; template @@ -527,7 +527,7 @@ namespace kernel_6 { // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -538,7 +538,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_e // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_relu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_relu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -549,7 +549,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_e // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_gelu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_row_bias_gelu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -560,7 +560,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_e // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -571,7 +571,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_e // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_relu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_relu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } @@ -582,7 +582,7 @@ TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_e // Acc: fp32 // Scale (alpha, beta): fp32 // D: bf16 -TEST(SM120_Device_Blockscaled_Gemm_fe2m1t_fe2m1n_fe2m1t_vs16_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_gelu) { +TEST(SM120_Device_Blockscaled_Gemm_nvf4t_nvf4n_nvf4t_tensor_op_f32_f32_epilogue, 128x128x256_alpha_beta_per_col_bias_gelu) { bool result = test::gemm::device::TestSmallFusion(1.0, 0.5); EXPECT_TRUE(result); } diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f16_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f16_tensor_op.cu index 458ecf65..38f53cd9 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f16_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f16_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe2m1n_f16n_void_f32_tensor_op, 128x64x128_1x1x1) ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -134,7 +134,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe2m1n_f16n_void_f16_tensor_op, 128x64x128_1x1x1) ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -185,7 +185,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe2m1n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -236,7 +236,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe2m1n_f16n_tensor_op_f16, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f32_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f32_tensor_op.cu index f83637d1..12dcc0e9 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f32_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f4_f32_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe2m1n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op.cu index 5c34b5e7..e51db253 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe3m2n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -134,7 +134,7 @@ TEST(SM120_Device_Gemm_fe2m3t_fe2m1n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op_narrow_output.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op_narrow_output.cu index 6e93ea3c..d9f93676 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op_narrow_output.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f16_tensor_op_narrow_output.cu @@ -85,7 +85,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe3m2n_f16n_tensor_op_fe2m3n, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -138,7 +138,7 @@ TEST(SM120_Device_Gemm_fe2m3t_fe2m1n_f16n_tensor_op_fe2m1t, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op.cu index 697cf3e3..864e5137 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe3m2n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -134,7 +134,7 @@ TEST(SM120_Device_Gemm_fe2m3t_fe2m1n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op_narrow_output.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op_narrow_output.cu index eb2278be..f7e3026f 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op_narrow_output.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f6_f32_tensor_op_narrow_output.cu @@ -85,7 +85,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe3m2n_f32n_tensor_op_fe2m3n, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -137,7 +137,7 @@ TEST(SM120_Device_Gemm_fe2m3t_fe2m1n_f32n_tensor_op_fe2m1t, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f16_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f16_tensor_op.cu index 1c969677..326ae459 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f16_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f16_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe5m2n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -134,7 +134,7 @@ TEST(SM120_Device_Gemm_fe4m3t_fe2m1n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f32_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f32_tensor_op.cu index 7d761104..25deaba9 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f32_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f4_f8_f32_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe2m1t_fe5m2n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -134,7 +134,7 @@ TEST(SM120_Device_Gemm_fe4m3t_fe2m1n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f16_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f16_tensor_op.cu index a62071ed..20675613 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f16_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f16_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe3m2t_fe3m2n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f32_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f32_tensor_op.cu index f5289211..1da47b41 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f32_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f6_f32_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe3m2t_fe3m2n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f16_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f16_tensor_op.cu index fd1f79a8..ef108e2a 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f16_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f16_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe3m2t_fe4m3n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -134,7 +134,7 @@ TEST(SM120_Device_Gemm_fe4m3t_fe3m2n_f16n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f32_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f32_tensor_op.cu index f5b5db4a..f88f7161 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f32_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f6_f8_f32_tensor_op.cu @@ -83,7 +83,7 @@ TEST(SM120_Device_Gemm_fe3m2t_fe4m3n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -134,7 +134,7 @@ TEST(SM120_Device_Gemm_fe4m3t_fe3m2n_f32n_tensor_op_f32, 128x64x128_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f16_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f16_tensor_op.cu index 70b7d67c..d93604cc 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f16_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f16_tensor_op.cu @@ -82,7 +82,7 @@ TEST(SM120_Device_Gemm_fe4m3t_fe4m3n_f16n_tensor_op_f32, 128x64x64_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f32_tensor_op.cu b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f32_tensor_op.cu index e5d0c46a..96c3b750 100644 --- a/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f32_tensor_op.cu +++ b/test/unit/gemm/device/sm120_tensorop_gemm/sm120_gemm_f8_f8_f32_tensor_op.cu @@ -82,7 +82,7 @@ TEST(SM120_Device_Gemm_fe4m3t_fe4m3n_f32n_tensor_op_f32, 128x64x64_1x1x1) { ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecializedCooperative + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/test/unit/pipeline/testbed_cluster_launch_control.h b/test/unit/pipeline/testbed_cluster_launch_control.h index 49f65d8a..50a68a14 100644 --- a/test/unit/pipeline/testbed_cluster_launch_control.h +++ b/test/unit/pipeline/testbed_cluster_launch_control.h @@ -137,7 +137,7 @@ public: return true; #endif -#if 1 +#if 0 bool is_success = false; for (int i = 0; i< 10; i++){ printf("iteration = %d\n", i); diff --git a/tools/library/src/operation_table.cu b/tools/library/src/operation_table.cu index dd2b48c6..bceeabf6 100644 --- a/tools/library/src/operation_table.cu +++ b/tools/library/src/operation_table.cu @@ -85,7 +85,7 @@ void OperationTable::append(Manifest const &manifest) { block_scaled_gemm_operations[functional_key][preference_key].push_back(op); } - + // insert all gemm operation into operation table if (desc.kind == OperationKind::kGemm) { @@ -121,6 +121,68 @@ void OperationTable::append(Manifest const &manifest) { gemm_operations[functional_key][preference_key].push_back(op); } + // insert all grouped gemm operation into operation table + if (desc.kind == OperationKind::kGroupedGemm) { + GroupedGemmDescription const &grouped_gemm_desc = static_cast(desc); + GemmDescription const &gemm_desc = grouped_gemm_desc.gemm; + + int cc = gemm_desc.tile_description.minimum_compute_capability; + + int alignment = std::max(std::max( + gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment); + + GemmPreferenceKey preference_key(cc, alignment); + + Operation const *op = operation.get(); + + if (!grouped_gemm_desc.block_scales.has_value()) { + GemmFunctionalKey functional_key( + gemm_desc.provider, + gemm_desc.gemm_kind, + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + gemm_desc.A.element, + gemm_desc.A.layout, + gemm_desc.transform_A, + gemm_desc.B.element, + gemm_desc.B.layout, + gemm_desc.transform_B, + gemm_desc.C.element, + gemm_desc.C.layout, + gemm_desc.D.element, + gemm_desc.D.layout + ); + + gemm_operations[functional_key][preference_key].push_back(op); + } + else { + const BlockScaleDescription &block_scale_desc = grouped_gemm_desc.block_scales.value(); + BlockScaledGemmFunctionalKey functional_key( + gemm_desc.provider, + gemm_desc.gemm_kind, + gemm_desc.kind, + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + gemm_desc.A.element, + gemm_desc.A.layout, + block_scale_desc.SFA.element, + gemm_desc.B.element, + gemm_desc.B.layout, + block_scale_desc.SFB.element, + gemm_desc.C.element, + gemm_desc.C.layout, + gemm_desc.D.element, + gemm_desc.D.layout, + block_scale_desc.SFD.element, + block_scale_desc.SFD.layout, + block_scale_desc.SFVecSize, + block_scale_desc.EpilogueSFVecSize + ); + + block_scaled_gemm_operations[functional_key][preference_key].push_back(op); + } + } + // insert all conv2d or conv3d operation into operation table if (desc.kind == OperationKind::kConv2d || desc.kind == OperationKind::kConv3d) { auto &conv_desc = static_cast(desc); diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 45be9e72..dd54dc6e 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -645,8 +645,9 @@ void gett_epilogue( (cute::is_same_v> or cute::is_same_v>) and cute::is_same_v; - constexpr bool IsClamp = - cute::is_same_v>; + constexpr bool UseReLU = + cute::is_same_v>; // Treat Clamp as ReLU + constexpr bool IsBackpropFusion = cute::is_same_v> or cute::is_same_v>; @@ -752,8 +753,9 @@ void gett_epilogue( } } - if constexpr (IsClamp) { // Treat Clamp as ReLU - output = activation(output, {0, std::numeric_limits::max()}); + if constexpr (UseReLU) { + cutlass::epilogue::thread::ReLU relu; + output = relu(output); } else { output = activation(output); diff --git a/tools/util/scripts/split_test_cmake.py b/tools/util/scripts/split_test_cmake.py new file mode 100644 index 00000000..6541ce1b --- /dev/null +++ b/tools/util/scripts/split_test_cmake.py @@ -0,0 +1,356 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +""" +Given a set of test files to be included in a CMake target, this script extracts +the TEST definitions from each file, writes them into new files, and prints the names +of the new files so that they can be processed as part of a new CMake target. + +For example, given a set of --src_files test_a.cu test_b.cu containing 3 and 2 TEST +definitions, respectively, this script would produce: + test_a_000.cu + test_a_001.cu + test_a_002.cu + test_b_000.cu + test_b_001.cu + +The splitting follows a fairly rudimentary algorithm that does not support all valid C++ programs. +We walk through a given input test file line by line. Any lines that are not within a TEST definition is added to a running +"filler" text. When a TEST definition is encountered, the current filler text becomes the prefix +for that test. All subsequent lines are considered to be part of the TEST definition until the +number of starting function braces ('{') match the number of closing function braces ('}'). When +these counts are equal, the TEST definition is considered to be completed. At this point, we return +to adding lines to the "filler" text until a new TEST definition is encountered. Any "filler" text +following a TEST definition is added to the suffix of that TEST definition (this is useful for finishing +off #if statements, as is common in unit tests.). + +A state machine illustrating this algorithm at a high level is provided in the source below. + +Example: Suppose an input test `test.cu` has the following source: + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +The contents of the two resulting test files will be: + $ cat test_000.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + $ cat test_001.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +Notice that each of test_000.cu and test_001.cu contain comments that appear outside +the TEST definitions not included in each file. This is by design, as these +would be considered "filler" text. + +As expected, some cases can't be handled. Below is a non-exhaustive list: + 1. New TEST following the closing '}' of a TEST case on the same line: + TEST(x, y) { + // Do stuff + } TEST(a, b) { + + In this case, "TEST(a, b) {" will be ignored + + 2. Preprocessor macros that occur midway through a test case and extend + beyond the conclusion of a testcase + + Example: + TEST(a, b) { + // Do stuff + #if X + // Do more stuff + } + #else + // Do other stuff + } + #endif +""" + + +import argparse +import enum +import os + + +parser = argparse.ArgumentParser() +parser.add_argument("cmake_target", type=str, + help="Name of the CMake target being generated.") +parser.add_argument("src_dir", type=str, + help="Path to the directory containing test files.") +parser.add_argument("--src_files", nargs='+', + help="Files containing TEST instances to split.") +parser.add_argument("--max_tests_per_file", type=int, default=1, + help="Maximum number of TEST instances per file.") +parser.add_argument("--dst_dir", type=str, + help="Path to the directory to which to write new test files. If not set, uses src_dir.") +args = parser.parse_args() + + +if args.dst_dir == None: + args.dst_dir = args.src_dir + + +class Testcase: + """ + Lightweight tracker of test-case processing status + """ + def __init__(self, prefix_text): + # Any text that preceded the TEST definition that was + # not part of another TEST definition + self.prefix = prefix_text + + # Any text within the TEST definition + self.test = "" + + # Any text that follows the completion of the TEST definition + # and is not included in other TEST definitions + self.suffix = "" + + # Whether the test's definition has concluded + self.completed = False + + # Current balance of opening and closing curly brackets in + # the TEST definition. '{' increments the count and '}' decrements it. + # A value of 0 (when self.completed == False) indicates that the test + # has completed. + self.curly_bracket_balance = 0 + + +class ParseState(enum.Enum): + """ + State machine for processing. + Transitions occur on each line encountered in the soruce file + + + Line does not contain 'TEST(' + +----+ + | | + | v 'TEST(' + +--------+ encountered +--------------------------+ + ------>| Filler | -----------------------> | TestDeclaredWaitingStart | + +--------+ +--------------------------+ + ^ | + Number of '{' | | First '{' encountered + equals number of | +--------+ | + '}' encountered +-----------| InTest | <------------------+ + +--------+ + | ^ + | | + +----+ + Number of '{' encountered + exceeds number of '}' encountered + """ + + + # Any text that is not part of a TEST case + Filler = 0 + + # Processing text within the first { of the TEST case + # and before the en of the final } of the TEST case + InTest = 1 + + # Processing text from the start of the TEST definition + # but before the first {. This could occur if the opening { + # occurs on a separate line than the TEST definition. + TestDeclaredWaitingStart = 2 + + +cmake_src_list = [] +for filename in args.src_files: + if '.' not in filename: + # Add any non-filename arguments to the command list by default + cmake_src_list.append(filename) + continue + + if '/' in filename: + raise Exception( + f"Source files passed to {__file__} must be within the same directory " + "as the CMakeLists defining the target using the files. " + f"Provided path {filename} is in a different directory.") + + full_filename = os.path.join(args.src_dir, filename) + with open(full_filename, 'r') as infile: + lines = infile.readlines() + + # Find the number of instances of "TEST(" + ntest = sum([1 for line in lines if "TEST(" in line]) + + if ntest <= args.max_tests_per_file: + # File contains fewer than max_tests_per_file TEST instances. It does + # not need to be split + cmake_src_list.append(filename) + continue + + # Current state of the parsing state machine. We start with filler text + state = ParseState.Filler + + # List of individual TESTs found + tests = [] + + # Ongoing text that is not included in a TEST definition. This will serve + # as the prefix for any yet-to-be encountered TEST definitions. + filler_text = "" + + def add_filler_text(text): + global filler_text + # Add new text to the ongoing filler text and to the suffixes of + # any completed tests + filler_text += text + for i in range(len(tests)): + if tests[i].completed: + tests[i].suffix += text + + for line in lines: + if state == ParseState.Filler: + # We are not currently within a TEST definition. + + if 'TEST(' in line: + # We have encountered a new TEST( case. Any text preceding this + # must be added to the filler text (e.g., if we have a line of the form: + # "static constexpr int Val = 4; TEST(blah) {" + # then "static constexpr int Val = 4;" needs to be included in filler + # text, as it could be used by subsequent tests.) + splits = line.split('TEST') + + # There should not be more than one TEST definition on a given line + assert len(splits) <= 2 + + if len(splits) > 1: + if not splits[0].isspace(): + # Only add text to filler if there are non-whitespace charcters + # preceding the TEST definition in the line + filler_text += splits[0] + + # The new line is just the TEST-related line + line = 'TEST' + splits[-1] + + # Add tests and transtion to TestDeclaredWaitingStart state. + # Do not add the line to the test text of the new test case; this + # will be done in either the TestDeclaredWaitingStart state processing + # below or in the InTest state processing below. + tests.append(Testcase(filler_text)) + state = ParseState.TestDeclaredWaitingStart + else: + # Any remaining filler text is added to the running filler_text + # which will be used as the prefix for any new tests, and to the + # suffix of any completed tests + add_filler_text(line) + + if state == ParseState.TestDeclaredWaitingStart: + # We have seen a TEST definition but have not yet seen its opening {. + + if '{' in line: + # The first curly bracket for the TEST definition has been found. + # Advance to state InTests. Do not add the line to the test's text + # or change the curly-brace balance of the test; these will be done + # when processing the state == ParseState.InTest condition below. + state = ParseState.InTest + else: + tests[-1].test += line + + if state == ParseState.InTest: + # We are currently within a TEST definition. + # Process lines character-by-character looking for opening and closing + # braces. If we reach parity between opening and closing braces, the + # test is considered done. + filler_text_to_add = "" + for char in line: + if not tests[-1].completed: + tests[-1].test += char + if char == '{': + tests[-1].curly_bracket_balance += 1 + elif char == '}': + tests[-1].curly_bracket_balance -= 1 + if tests[-1].curly_bracket_balance == 0: + tests[-1].completed = True + else: + filler_text_to_add += char + + if filler_text_to_add != "" and (not filler_text_to_add.isspace() or '\n' in filler_text_to_add): + add_filler_text('\n' + filler_text_to_add) + + if tests[-1].completed: + state = ParseState.Filler + + # Write out the new files for tests + filename_prefix, filename_suffix = filename.split('.') + for i, test in enumerate(tests): + assert test.completed + new_filename = filename_prefix + '_' + str(i).zfill(3) + '.' + filename_suffix + full_new_filename = os.path.join(args.dst_dir, new_filename) + + # Replace any '\' with '/'. CMake doesn't like '\'. + full_new_filename = full_new_filename.replace('\\', '/') + + with open(full_new_filename, 'w') as outfile: + outfile.write(test.prefix + test.test + test.suffix) + cmake_src_list.append(full_new_filename) + + +for cmake_file in cmake_src_list: + print(cmake_file)