Compare commits

5 commits

| Author | SHA1 | Date |
|---|---|---|
| | e9a75581fe | |
| | ac210faef8 | |
| | 15f5468872 | |
| | af5519d938 | |
| | 415d587ebf | |

CHANGELOG.md (106 changes)
@@ -1,60 +1,5 @@
# NVIDIA CUTLASS Changelog

## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03)

* Fixed [Blockwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM hang issue when problem size K is 128.
* Optimal code generation with CUDA toolkit version 12.9.

## [3.9.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.1) (2025-04-30)

* Fixed a grouped GEMM hang issue in CUTLASS 3.x.
* Improved Hopper [Blockwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM performance.

## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-04-24)

* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
  - Collective mainloops that target:
    * [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
    * [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
  - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
  - [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
  - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
  - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
  - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
  - [Grouped GEMM with NVFP4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
  - [Sparse Blockscaled GEMM with MXFP8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
  - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
* Support for Blackwell SM100 Sparse kernels:
  - A collective mainloop that targets:
    * [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
  - [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
  - [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
  - [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
* A new Multi-head Latent Attention (MLA) kernel for the SM100 Blackwell architecture in the CUTLASS [example](./examples/77_blackwell_fmha/) covers the FlashMLA-like weight-absorbed decoding use case.
* A new FMHA backward kernel for the SM100 Blackwell architecture extends the CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward-pass MMAs can be fused into a single kernel to achieve high performance.
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
* Enhancements and new support for block-wise and group-wise GEMM on Hopper and Blackwell architectures:
  - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
  - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
  - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
  - Support for [groupwise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in the CUTLASS profiler.
  - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
  - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
  - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
* Added support for enhanced kernel performance search (auto-tuning) in the CUTLASS profiler:
  - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
  - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
  - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
  - More detailed introductions and examples for leveraging this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
* Support for `void` as the D element in SM100 kernel epilogues.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit version 12.8U1.

## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25)

* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:

@@ -69,7 +14,7 @@
  - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
  - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
  - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
  - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
  - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
  - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
  - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that

@@ -107,11 +52,11 @@
  - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
  - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
  - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-gemm-kernel).
  - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md)
  - A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures.
  - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
  - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
  - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures.
  - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
  - Updates to [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper.
  - Updates to [profiler documentation](./media/docs/profiler.md) for testing mixed input GEMM kernels on Hopper.

## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
- [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses a 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
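The blockwise-scaling bullet above describes per-tile scale factors applied once per gemm-k iteration. As a rough illustration of that accumulation order only (plain C++, not CUTLASS code), the host-side reference below uses illustrative block sizes, row-major layouts, and a hypothetical function name; none of these come from the example itself.

```cpp
#include <vector>

// Schematic host reference for block-scaled GEMM:
// D(m,n) = sum over k-blocks of scaleA(m-block, k-block) * scaleB(n-block, k-block) * partial(A, B).
// BlkM/BlkN/BlkK and the row-major indexing are illustrative assumptions, not CUTLASS's actual tiling.
void blockwise_scaled_gemm_ref(int M, int N, int K, int BlkM, int BlkN, int BlkK,
                               std::vector<float> const& A,        // M x K, row-major
                               std::vector<float> const& B,        // K x N, row-major
                               std::vector<float> const& scale_A,  // ceil(M/BlkM) x ceil(K/BlkK)
                               std::vector<float> const& scale_B,  // ceil(N/BlkN) x ceil(K/BlkK)
                               std::vector<float>&       D) {      // M x N, row-major
  int kblks = (K + BlkK - 1) / BlkK;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int kb = 0; kb < kblks; ++kb) {
        // One scale per (output block, k-block), applied once per gemm-k iteration.
        float s = scale_A[(m / BlkM) * kblks + kb] * scale_B[(n / BlkN) * kblks + kb];
        float partial = 0.f;
        for (int k = kb * BlkK; k < K && k < (kb + 1) * BlkK; ++k) {
          partial += A[m * K + k] * B[k * N + n];
        }
        acc += s * partial;
      }
      D[m * N + n] = acc;
    }
  }
}
```

In the linked example the operands are FP8 and the scales live in shared memory; the sketch is only meant to show where the per-block scale enters the accumulation.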
@@ -124,7 +69,7 @@
  + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
  + Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
  + A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernels.
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler).
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/profiler.md#cutlass-profiler).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
- Optimal code generation with CUDA toolkit version 12.6.

@@ -138,12 +83,12 @@
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered a beta API.
- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speed up two back-to-back kernels, and its corresponding [documentation](./media/docs/cpp/dependent_kernel_launch.md).
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speed up two back-to-back kernels, and its corresponding [documentation](./media/docs/dependent_kernel_launch.md).
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper).
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper).
- New hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h).
- Fixed use of `isnan` on Windows for [`half_t`](./test/unit/core/functional.cu).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
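The `cutlass::bfloat16_t` and `half_t` bullets just above concern CUTLASS's numeric types. The snippet below is a minimal host-only sketch of how those types are typically used for conversion, arithmetic, comparison, and NaN checks; it assumes only the public headers named in the bullets and that the host-side operator and `isnan` overloads behave as described in those headers.

```cpp
#include <iostream>
#include "cutlass/half.h"      // cutlass::half_t
#include "cutlass/bfloat16.h"  // cutlass::bfloat16_t

int main() {
  // Construct from float, compute, and compare on the host.
  cutlass::half_t     h = cutlass::half_t(1.5f);
  cutlass::bfloat16_t b = cutlass::bfloat16_t(2.25f);

  cutlass::bfloat16_t sum = b + cutlass::bfloat16_t(float(h));  // promote half -> float -> bf16
  bool greater = (sum > b);                                     // comparison operators are defined

  // NaN checks for the CUTLASS numeric types live in the cutlass namespace.
  cutlass::half_t bad = cutlass::half_t(0.0f) / cutlass::half_t(0.0f);
  std::cout << float(sum) << " " << greater << " " << cutlass::isnan(bad) << "\n";
  return 0;
}
```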
@@ -153,7 +98,7 @@

- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu)
- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48)
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/cpp/profiler.md#GEMM), and
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
  [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inference:
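The `GemmSparseUniversal` bullet above concerns Ampere's 2:4 structured-sparse tensor cores, which consume a compressed operand plus metadata recording which 2 of every 4 elements were kept. The sketch below shows that compression on the host purely for clarity; the 2-bit-per-index metadata packing and the function name are illustrative assumptions, not the hardware's actual encoding or a CUTLASS API.

```cpp
#include <cstdint>
#include <cmath>
#include <utility>
#include <vector>

// Compress a row into 2:4 structured-sparse form: keep the 2 largest-magnitude
// values of every group of 4 and record their positions. The metadata packing
// below (two 2-bit indices per byte) is only illustrative.
void compress_2of4(std::vector<float> const& row,
                   std::vector<float>&       values,     // row.size()/2 kept values
                   std::vector<uint8_t>&     metadata) { // one byte per group of 4
  for (size_t g = 0; g + 4 <= row.size(); g += 4) {
    int first = 0, second = 1;
    if (std::fabs(row[g + second]) > std::fabs(row[g + first])) std::swap(first, second);
    for (int i = 2; i < 4; ++i) {
      if (std::fabs(row[g + i]) > std::fabs(row[g + first]))       { second = first; first = i; }
      else if (std::fabs(row[g + i]) > std::fabs(row[g + second])) { second = i; }
    }
    int lo = first < second ? first : second;
    int hi = first < second ? second : first;
    values.push_back(row[g + lo]);
    values.push_back(row[g + hi]);
    metadata.push_back(static_cast<uint8_t>(lo | (hi << 2)));  // two 2-bit indices
  }
}
```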
@@ -166,7 +111,7 @@
- Support for residual add (beta != 0) in convolution kernels.
- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md).
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md).
- Better support for MSVC as a host compiler.
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.

@@ -174,7 +119,7 @@
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)

- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
  + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md).
  + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
  + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
  + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
  + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.

@@ -186,7 +131,7 @@
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
  + [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
  + [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
- Fixes to greatly reduce build warnings.

@@ -205,7 +150,7 @@
* Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
* NamedBarriers usability improvements, and the list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released.
* Improved [CuTe documentation](./media/docs/cpp/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
* Improved [CuTe documentation](./media/docs/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.

## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
* [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.

@@ -256,7 +201,7 @@
* Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
* Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
* Changes to the [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
* [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
* [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) can now run multiple back-to-back GEMMs with the same problem size in parallel.

@@ -268,10 +213,10 @@
* Updates and bugfixes from the community (thanks!)

## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
* [CuTe](./media/docs/cpp/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
* [A new conceptual operation hierarchy](./media/docs/cpp/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/cpp/gemm_api_3x.md).
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units (a host-side launch sketch follows this list). More information can be found in the [3.x backwards compatibility section](./media/docs/cpp/cutlass_3x_backwards_compatibility.md).
* Updates to [Functionality](./media/docs/cpp/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md).
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md).
* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
* Updates to [Compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture).
* New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
* Extensions to the CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
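Because both 2.x and 3.x kernels are exposed through `device::GemmUniversalAdapter` (see the backwards-compatibility bullet earlier in this list), the host-side launch pattern is the same for either API generation. The sketch below shows that pattern as a generic helper under the assumption that the concrete `Gemm` type and its `Arguments` are produced elsewhere (for example by a collective builder or the CUTLASS library); it is an illustration of the call sequence, not a complete example.

```cpp
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

// Generic host-side launch pattern shared by CUTLASS 2.x and 3.x device-level GEMMs,
// e.g. cutlass::gemm::device::GemmUniversalAdapter<Kernel>. The Gemm type and its
// Arguments are assumed to be defined elsewhere; only the call sequence is shown.
template <class Gemm>
cutlass::Status run_gemm(typename Gemm::Arguments const& args, cudaStream_t stream = nullptr) {
  Gemm gemm_op;

  // 1. Ask the kernel whether it can implement this problem (alignment, shapes, ...).
  cutlass::Status status = gemm_op.can_implement(args);
  if (status != cutlass::Status::kSuccess) { return status; }

  // 2. Query and allocate the kernel's workspace (may be zero bytes).
  size_t workspace_bytes = Gemm::get_workspace_size(args);
  void* workspace = nullptr;
  if (workspace_bytes && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }

  // 3. Initialize with arguments and workspace, then launch on the given stream.
  status = gemm_op.initialize(args, workspace, stream);
  if (status == cutlass::Status::kSuccess) {
    status = gemm_op.run(stream);
  }

  cudaFree(workspace);
  return status;
}
```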
@@ -449,7 +394,7 @@
  * Global memory iterators supporting Fprop, Dgrad, and Wgrad
  * `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
  * `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
* [Documentation](./media/docs/cpp/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
* [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation

## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)

@@ -463,7 +408,7 @@
* NVIDIA Ampere GPU Architecture examples and documentation:
  * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
  * [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/cpp/gemm_api.md#efficient-epilogue)
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue)

## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)

@@ -483,7 +428,7 @@
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`

## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/cpp/quickstart.md#cutlass-library)
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library)
  * API to launch compiled kernel instances for GEMM and planar complex GEMM
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
  * Computes complex matrix products on matrices stored as disjoint real and imaginary parts

@@ -497,10 +442,10 @@
  * Encapsulated functionality embodying modern C++11 programming techniques
  * Optimized containers and data types for efficient, generic, portable device code
* Updates to:
  * [Quick start guide](./media/docs/cpp/quickstart.md)
  * [Quick start guide](./media/docs/quickstart.md)
  * [Documentation](./README.md#documentation)
  * [Utilities](./media/docs/cpp/utilities.md)
  * [CUTLASS Profiler](./media/docs/cpp/profiler.md)
  * [Utilities](./media/docs/utilities.md)
  * [CUTLASS Profiler](./media/docs/profiler.md)
* Native Turing Tensor Cores
  * Efficient GEMM kernels targeting Turing Tensor Cores
  * Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands

@@ -593,3 +538,4 @@ SPDX-License-Identifier: BSD-3-Clause
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
CMakeLists.txt
@@ -102,8 +102,6 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -ftemplate-backtrace-limit=0)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
endif()

@@ -175,7 +173,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 101 101a 120 120a)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")

@@ -384,21 +382,7 @@ endif()
if (CUTLASS_ENABLE_GDC_FOR_SM90)
  message(STATUS "Grid Dependency Control (GDC) is enabled for SM90 kernels (required for programmatic dependent launches).")
  list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1)
endif()
if (NOT DEFINED CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT)
  set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT ON)
endif()
set(CUTLASS_ENABLE_GDC_FOR_SM100
  ${CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT}
  CACHE BOOL
  "Enables Grid Dependency Control (GDC) for SM100 kernels (required for PDL).")
if (CUTLASS_ENABLE_GDC_FOR_SM100)
  message(STATUS "Grid Dependency Control (GDC) is enabled for SM100 kernels (required for programmatic dependent launches).")
  list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM100=1)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1)
endif()
set(CUTLASS_ENABLE_SYNCLOG OFF CACHE BOOL "Enable synchronization event logging for race condition debugging. WARNING: This redefines __syncthreads() and __syncwarp() in all downstream code!")

@@ -443,7 +427,7 @@ if (NOT MSVC AND CUTLASS_NVCC_KEEP)
  # MSVC flow handles caching already, but for other generators we handle it here.
  set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
  file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v -objtemp) # --keep-dir may not work with nvcc for some directories.
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v) # --keep-dir may not work with nvcc for some directories.
  list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
endif()

@@ -470,13 +454,6 @@ if(UNIX)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
endif()
# Known ctk11.4 issue (fixed later)
# Also see https://stackoverflow.com/questions/64523302/cuda-missing-return-statement-at-end-of-non-void-function-in-constexpr-if-fun
if (CUDA_VERSION VERSION_LESS 11.5.0)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcudafe "--diag_suppress=implicit_return_from_non_void_function" )
  message("CUDA_VERSION check pass ${CUDA_VERSION}")
endif()
# Don't leak lineinfo in release builds
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
  list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt)

@@ -713,7 +690,6 @@ target_include_directories(
  CUTLASS
  SYSTEM INTERFACE
  $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
  $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
  )
install(

@@ -1055,7 +1031,6 @@ function(cutlass_generate_profiler_tests NAME)
  string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}")
  string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}")
  string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}")
  string(REPLACE "swizzle_size=" "" TEST_NAME "${TEST_NAME}")
  string(REGEX REPLACE "verification_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}")
  string(REGEX REPLACE "warmup_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")
  string(REGEX REPLACE "profiling_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")
CONTRIBUTORS.md

@@ -128,35 +128,3 @@ Bryce Lelbach<br />
Joel McCormack<br />
Kyrylo Perelygin<br />
Sean Treichler<br />

# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
PUBLICATIONS.md

@@ -2,16 +2,10 @@

## 2025

- ["Comet: Fine-grained Computation-communication Overlapping for Mixture-of-Experts"](https://arxiv.org/abs/2502.19811). Shulai Zhang, Ningxin Zheng, Haibin Lin, Ziheng Jiang, Wenlei Bao, Chengquan Jiang, Qi Hou, Weihao Cui, Size Zheng, Li-Wen Chang, Quan Chen, Xin Liu. _arXiv_, February 2025.
- ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025.
- ["Generalized Neighborhood Attention: Multi-dimensional Sparse Attention at the Speed of Light"](https://arxiv.org/abs/2504.16922). Ali Hassani, Fengzhe Zhou, Aditya Kane, Jiannan Huang, Chieh-Yun Chen, Min Shi, Steven Walton, Markus Hoehnerbach, Vijay Thakkar, Michael Isaev, Qinsheng Zhang, Bing Xu, Haicheng Wu, Wen-mei Hwu, Ming-Yu Liu, Humphrey Shi. _arXiv_, April 2025.

## 2024

- ["DeepSeek-V3 Technical Report"](https://arxiv.org/abs/2412.19437). DeepSeek-AI. _arXiv_, December 2024.
- ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024.
- ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024.

@@ -70,35 +64,3 @@
"](https://arxiv.org/abs/2008.13006). Cong Guo, Bo Yang Hsueh, Jingwen Leng, Yuxian Qiu, Yue Guan, Zehuan Wang, Xiaoying Jia, Xipeng Li, Minyi Guo, Yuhao Zhu. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.
- ["Strassen's Algorithm Reloaded on GPUs"](https://dl.acm.org/doi/10.1145/3372419). Jianyu Huang, Chenhan D. Yu, Robert A. van de Geijn. _ACM Transactions on Mathematical Software_, March 2020.

## Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
README.md (155 changes)
@@ -1,8 +1,8 @@


# CUTLASS 3.9.2
# CUTLASS 3.8.0

_CUTLASS 3.9.2 - May 2025_
_CUTLASS 3.8.0 - January 2025_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels

@@ -32,53 +32,71 @@ the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution
operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline.
This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.

See the [Quick Start Guide](./media/docs/cpp/quickstart.md) to get started quickly.
See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.

See the [functionality docs](./media/docs/cpp/functionality.md) for a more comprehensive
See the [functionality docs](./media/docs/functionality.md) for a more comprehensive
list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU
architecture.

# What's New in CUTLASS 3.9
# What's New in CUTLASS 3.8

* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
  - Collective mainloops that target:
    * [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
    * [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
  - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
  - [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
  - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
  - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
  - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
  - [Grouped GEMM with NVFP4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
  - [Sparse Blockscaled GEMM with MXFP8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
  - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
* Support for Blackwell SM100 Sparse kernels:
  - A collective mainloop that targets:
    * [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
  - [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
  - [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
  - [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
* A new Multi-head Latent Attention (MLA) kernel for the SM100 Blackwell architecture in the CUTLASS [example](./examples/77_blackwell_fmha/) covers the FlashMLA-like weight-absorbed decoding use case.
* A new FMHA backward kernel for the SM100 Blackwell architecture extends the CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward-pass MMAs can be fused into a single kernel to achieve high performance.
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
* Enhancements and new support for block-wise and group-wise GEMM on Hopper and Blackwell architectures:
  - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
  - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
  - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
  - Support for [groupwise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in the CUTLASS profiler.
  - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
  - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
  - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
* Added support for enhanced kernel performance search (auto-tuning) in the CUTLASS profiler:
  - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
  - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
  - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
  - More detailed introductions and examples for leveraging this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
* Support for `void` as the D element in SM100 kernel epilogues.
CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture.
For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8.

* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
  - [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
  - Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
  - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
  - Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
  - [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
  - Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
  - Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
  - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
  - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
  - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
  - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
  - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
  - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
    + Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
    + Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
    + Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
  - Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
    * [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
    * [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
    * [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
    * [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
  - Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
  - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
  - [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
  - Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
  - Support for dynamic datatypes via profiler command line argument parsing to set the dynamic datatype in TCGen05 MMA instruction descriptors.
  - Support for mixed input GEMM kernels on Hopper in the profiler.
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
* A new 3.x version of grouped GEMM added to the CUTLASS library that generates kernels for Hopper and Blackwell. Grouped GEMM support is now enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
  - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off-the-shelf kernels using the 3.x collective builder API.
  - GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
  - Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
    + [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
    + [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
    + [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
  - GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
  - [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
  - Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
  - Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
  - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32, 64, and 128.
  - A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations (a schematic of the BF16x9 splitting idea follows this list).
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
  - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
  - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
  - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
  - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
  - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures.
  - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
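As referenced in the BF16x9 bullet above, FP32 emulation with BF16 relies on splitting each FP32 operand into three BF16 pieces whose sum approximates the original value, so one FP32 product becomes up to nine BF16 products accumulated in FP32. The code below is only a host-side sketch of that splitting idea using `cutlass::bfloat16_t`; the rounding details, the number of partial products actually kept, and the accumulation order in the real kernel are not shown here.

```cpp
#include <cstdio>
#include "cutlass/bfloat16.h"

// Split an FP32 value into three BF16 components: x ~= hi + mid + lo.
// Each step converts the remaining residual to bf16 and subtracts it back out.
static void split_bf16x3(float x, cutlass::bfloat16_t out[3]) {
  float r = x;
  for (int i = 0; i < 3; ++i) {
    out[i] = cutlass::bfloat16_t(r);   // convert the leading bits to bf16
    r -= float(out[i]);                // residual carries the bits that were lost
  }
}

int main() {
  float a = 1.2345678f, b = -7.6543210f;
  cutlass::bfloat16_t as[3], bs[3];
  split_bf16x3(a, as);
  split_bf16x3(b, bs);

  // Emulated product: sum of the 3 x 3 = 9 BF16 partial products, accumulated in FP32.
  float emulated = 0.f;
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 3; ++j)
      emulated += float(as[i]) * float(bs[j]);

  std::printf("exact %.9f  emulated %.9f\n", double(a) * double(b), emulated);
  return 0;
}
```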
Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
@ -125,7 +143,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
This greatly simplifies the design and improves code composability and readability.
More documentation specific to CuTe can be found in its
[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md).
[dedicated documentation directory](./media/docs/cute/00_quickstart.md).

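To give a flavor of the layout algebra referenced above, here is a minimal, self-contained sketch (shapes chosen purely for illustration; it assumes CUTLASS's `include/` directory is on the include path):

```cpp
#include <iostream>

#include "cute/layout.hpp"

int main() {
  using namespace cute;

  // An 8x4 column-major layout: shape (8,4) with strides (1,8).
  auto layout = make_layout(make_shape(Int<8>{}, Int<4>{}),
                            make_stride(Int<1>{}, Int<8>{}));

  // Functional composition: index through `tiler` first, then through `layout`.
  auto tiler    = make_layout(make_shape(Int<4>{}, Int<2>{}));
  auto composed = composition(layout, tiler);

  print(layout);   std::cout << "\n";
  print(composed); std::cout << "\n";
  return 0;
}
```

The composition here simply views the first eight elements of the 8x4 column-major layout through a 4x2 tiler; the CuTe quickstart linked above develops this algebra in full.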
# Compatibility

@ -172,7 +190,6 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|NVIDIA GeForce RTX 50x0 series |10.0|12.8|

## Target Architecture

@ -208,7 +225,7 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
compiled for Blackwell SM100 architecture with arch conditional features
(using `sm100a`) are not compatible with RTX 50 series GPUs.

Please refer to the [functionality documentation](./media/docs/cpp/functionality.md)
Please refer to the [functionality documentation](./media/docs/functionality.md)
for details on which kernels require which target architectures.
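In source code, this surfaces as arch-conditional feature macros; for instance, Hopper-specific example kernels elsewhere in this tree are guarded by `CUTLASS_ARCH_MMA_SM90_SUPPORTED`. A minimal sketch of such a guard (the printed strings are illustrative only):

```cpp
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/arch/mma.h"  // where arch feature macros are typically defined

int main() {
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  // The toolkit/target-architecture combination enables SM90 (Hopper) MMA paths.
  std::cout << "SM90 arch-conditional features are available in this build.\n";
#else
  std::cout << "SM90 arch-conditional features are not enabled for this build.\n";
#endif
  return 0;
}
```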

# Documentation
@ -216,22 +233,22 @@ for details on which kernels require which target architectures.
CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).

- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS
- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code
- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory
- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development
- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS
- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](./media/docs/terminology.md) - describes terms used in the code
- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development
- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
kernels in the same stream, and how it is used in CUTLASS.

# Resources

@ -251,7 +268,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
paths.

CUTLASS unit tests, examples, and utilities can be built with CMake.
The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md).
The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.

@ -296,7 +313,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
and template concepts defined in the CUTLASS project.

A detailed explanation of the source code organization may be found in the
[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below.
[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below.

## CUTLASS Template Library

@ -370,7 +387,7 @@ tools/
The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.

Instructions for building and running the unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md).
Instructions for building and running the unit tests are described in the [Quickstart guide](./media/docs/quickstart.md).

# Performance Profiling

@ -586,9 +603,9 @@ reference_device: Passed

## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
- [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md)
- [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md)


# About

@ -65,10 +65,10 @@ endfunction()

if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)

  set(PROFILER_ARCH_LIST 100a 101a 120a)
  set(PROFILER_ARCH_LIST 100a)
  foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
    if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
      message(FATAL_ERROR "Only SM100a/101a/120a compute capability is supported with profiler-based unit tests")
      message(FATAL_ERROR "Only SM100a compute capability is supported with profiler-based unit tests")
    endif()
  endforeach()


@ -34,7 +34,7 @@
addressable memory, and then store it back into addressable memory.

TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to
and from addressable memory. The PredicatedTileIterator accepts a ThreadMap type, which defines
and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines
the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined
thread mappings to be specified.

@ -124,7 +124,7 @@ __global__ void copy(

cudaError_t TestTileIterator(int M, int K) {

  // For this example, we chose a <64, 4> tile shape. The PredicatedTileIterator expects
  // For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects
  // PitchLinearShape and PitchLinear layout.
  using Shape = cutlass::layout::PitchLinearShape<64, 4>;
  using Layout = cutlass::layout::PitchLinear;
@ -136,7 +136,7 @@ cudaError_t TestTileIterator(int M, int K) {
  // dimension then along the strided dimension.
  using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<Shape, kThreads>;

  // Define the PredicatedTileIterator, using TileShape, Element, Layout, and ThreadMap types
  // Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types
  using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
      Shape, Element, Layout, 1, ThreadMap>;

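For context, this is roughly how such an iterator is driven inside a copy kernel. The sketch below follows the CUTLASS 2.x tile-iterator concept (`Params`, `Fragment`, `load`/`store`, `operator++`); treat the exact loop structure as illustrative rather than the example's verbatim kernel:

```cpp
// Hedged sketch of a device-side copy loop using a predicated tile iterator.
template <typename Iterator>
__global__ void copy_with_tile_iterator(
    typename Iterator::Params dst_params, typename Iterator::Element *dst_pointer,
    typename Iterator::Params src_params, typename Iterator::Element *src_pointer,
    typename Iterator::TensorCoord extent) {

  Iterator dst_iterator(dst_params, dst_pointer, extent, threadIdx.x);
  Iterator src_iterator(src_params, src_pointer, extent, threadIdx.x);

  // Number of whole tiles needed to cover the strided extent.
  int iterations =
      (extent.strided() + Iterator::Shape::kStrided - 1) / Iterator::Shape::kStrided;

  typename Iterator::Fragment fragment;
  fragment.clear();

  for (; iterations > 0; --iterations) {
    src_iterator.load(fragment);   // predicated loads: out-of-bounds lanes are guarded
    dst_iterator.store(fragment);  // predicated stores
    ++src_iterator;                // advance both iterators to the next tile
    ++dst_iterator;
  }
}
```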
@ -115,3 +115,4 @@ SPDX-License-Identifier: BSD-3-Clause
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -2,35 +2,3 @@
|
||||
|
||||
This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface.
|
||||
For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -165,35 +165,3 @@ Example 7: GELU
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
|
||||
```
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -402,7 +402,7 @@ struct Options : MixedDtypeOptions{
void initialize(Options const& options) {

  auto shape_B = cute::make_shape(options.n, options.k, options.l);
  int const scale_k = cutlass::ceil_div(options.k, options.g);
  int const scale_k = (options.k + options.g - 1) / options.g;
  stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
  stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
  // Reverse stride here due to swap and transpose
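One side of this hunk uses `cutlass::ceil_div`, the other the explicit `(k + g - 1) / g` form; the two are equivalent integer ceiling divisions for positive operands, as this tiny standalone check illustrates (the helper name is made up for the sketch):

```cpp
#include <cassert>

// Hypothetical helper mirroring the expression above:
// number of scale groups along K, i.e. ceil(k / g) in integer arithmetic.
int scale_k_for(int k, int g) {
  return (k + g - 1) / g;   // equivalent to cutlass::ceil_div(k, g) for k, g > 0
}

int main() {
  assert(scale_k_for(4096, 128) == 32);  // evenly divisible
  assert(scale_k_for(4100, 128) == 33);  // remainder rounds the count up
  return 0;
}
```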
|
||||
@ -429,7 +429,7 @@ void initialize(Options const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
@ -318,7 +318,7 @@ struct Options : MixedDtypeOptions {
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@ -347,7 +347,7 @@ void initialize(Options const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
|
||||
@ -288,7 +288,7 @@ cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::Ele
|
||||
void initialize(MixedDtypeOptions const& options) {
|
||||
|
||||
auto shape_b = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@ -313,7 +313,7 @@ void initialize(MixedDtypeOptions const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
@ -41,35 +41,3 @@ We are currently optimizing the following cases:
|
||||
* Optimizations for memory bound cases.
|
||||
|
||||
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -208,6 +208,20 @@ bool initialize_tensor(
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
bool initialize_quant_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
|
||||
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_scale(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
@ -218,8 +232,10 @@ bool initialize_scale(
  float scope_max = 1.0f, scope_min = 1.0f;
  if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
    float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
    scope_max = 2.f;
    scope_min = 0.1f;
    const float max_dequant_val = 4.f;
    const float min_dequant_val = 0.5f;
    scope_max = max_dequant_val / elt_max_f;
    scope_min = min_dequant_val / elt_max_f;
  }
  cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
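The `max_dequant_val / elt_max_f` form ties the random scale range to the quantized element's largest magnitude, so that a full-scale quantized value dequantizes into a fixed range. A small standalone sketch of that arithmetic (the signed 4-bit maximum of 7 is an assumed example, not taken from the diff):

```cpp
#include <cstdio>

int main() {
  // Target dequantized magnitude range, matching the constants above.
  const float max_dequant_val = 4.f;
  const float min_dequant_val = 0.5f;

  // Assumed quantized element type for this sketch: signed int4, whose maximum is 7.
  const float elt_max = 7.f;

  // Drawing scales uniformly from [scope_min, scope_max] keeps scale * elt_max
  // inside [min_dequant_val, max_dequant_val].
  const float scope_max = max_dequant_val / elt_max;   // ~0.571
  const float scope_min = min_dequant_val / elt_max;   // ~0.071
  std::printf("scale range: [%f, %f]\n", scope_min, scope_max);
  return 0;
}
```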
|
||||
|
||||
@ -207,35 +207,3 @@ With this in mind, this example kernel has the following limitations:
|
||||
- This example kernel only supports dynamic image count, all other conv problem shape must be defined as `cute::Constant<>`s
|
||||
- Problem shapes (including dynamic image count `N`) must be evenly divisible by the tile shape
|
||||
- It does not perform fp32->tf32 numeric conversion, gmem inputs must be rounded to tf32 already
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -26,13 +26,11 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
set(TEST_PREFETCH_CASE --m=8192 --n=64 --k=8192 --iterations=0)
|
||||
include_directories(
|
||||
.
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
63_hopper_gemm_with_weight_prefetch
|
||||
63_hopper_gemm_with_weight_prefetch.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_PREFETCH_CASE
|
||||
)
|
||||
|
||||
target_include_directories(63_hopper_gemm_with_weight_prefetch PUBLIC .)
|
||||
)
|
||||
|
||||
@ -74,40 +74,9 @@ echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
|
||||
However, note that the example still runs a single GEMM, and most of the performance improvement
|
||||
is expected in end to end applications.
|
||||
|
||||
|
||||
## Limitations
|
||||
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
|
||||
When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
|
||||
memory barrier before issuing every single TMA load, and in many cases this will slow down
|
||||
prefetching to the point of being almost ineffective.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -362,11 +362,11 @@ public:
|
||||
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
|
||||
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z,
|
||||
/*id*/ 0);
|
||||
/*reserved_named_barriers_*/ 14);
|
||||
// Prefetcher warp doesn't arrive on this barrier.
|
||||
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
|
||||
/*id*/ 1);
|
||||
/*reserved_named_barriers_*/ 15);
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
||||
__syncwarp();
|
||||
|
||||
@ -120,7 +120,8 @@
|
||||
#include "helper.h"
|
||||
|
||||
// Distributed GEMM helpers
|
||||
#include "dist_gemm_helpers.h"
|
||||
#include "util/benchmark.h"
|
||||
#include "util/device_copy.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -833,10 +834,10 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
|
||||
<< "(compute capability 90)." << std::endl;
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -62,40 +62,3 @@ procedure is the same, simply modify the following line in the example:
|
||||
```cpp
|
||||
using TP = _8;
|
||||
```
|
||||
|
||||
## References
|
||||
* [Distributed GEMM Blog](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b)
|
||||
* [Distributed GEMM Talk on CUDA Mode](https://www.youtube.com/watch?v=NHRTCQBZokg)
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -17,8 +17,6 @@ Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit ar
|
||||
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
|
||||
CUDA graph APIs.
|
||||
|
||||
The minimum CUDA driver version for running this example is [560.28.03](https://docs.nvidia.com/cuda/archive/12.6.0/cuda-toolkit-release-notes/index.html#id5).
|
||||
|
||||
### Hardware / driver settings
|
||||
|
||||
This example requires Hopper GPUs with NVLink network.
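A quick programmatic sanity check along these lines is sketched below (hedged: peer access is a necessary but not sufficient indicator of an NVLink path, and the device indices are illustrative):

```cpp
#include <cstdio>

#include <cuda_runtime.h>

int main() {
  int driver_version = 0;
  cudaDriverGetVersion(&driver_version);   // reported as 1000*major + 10*minor, e.g. 12060
  std::printf("CUDA driver supports toolkit version: %d\n", driver_version);

  int can_access = 0;
  cudaDeviceCanAccessPeer(&can_access, /*device=*/0, /*peerDevice=*/1);
  std::printf("GPU0 -> GPU1 peer access: %s\n", can_access ? "OK" : "unavailable");
  return 0;
}
```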
|
||||
@ -86,35 +84,3 @@ GPU5 OK OK OK OK OK X OK OK
|
||||
GPU6 OK OK OK OK OK OK X OK
|
||||
GPU7 OK OK OK OK OK OK OK X
|
||||
```
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -44,11 +44,6 @@
|
||||
#include <cuda/atomic>
|
||||
#include <cuda/std/atomic>
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -120,46 +115,4 @@ struct DistGpuTimer {
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Generic device-to-device data movement kernel based for CuTe tensors.
|
||||
///
|
||||
/// NOTE: this kernel assigns one element copy to every thread, and is by no means
|
||||
/// an efficient way of copying tensors. It should only be used for convenience in
|
||||
/// reference checks.
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream);
|
||||
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
__global__ void device_copy_kernel(TensorSource const tensor_source,
|
||||
TensorDestination tensor_destination) {
|
||||
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
using ElementSrc = typename TensorSource::value_type;
|
||||
using ElementDst = typename TensorDestination::value_type;
|
||||
NumericConverter<ElementDst, ElementSrc> converter;
|
||||
if (linear_idx < size(tensor_source)) {
|
||||
tensor_destination(linear_idx) = converter(tensor_source(linear_idx));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream) {
|
||||
|
||||
assert(tensor_source.size() == tensor_destination.size());
|
||||
|
||||
auto numel = tensor_source.size();
|
||||
static constexpr int NumThreads = 128;
|
||||
auto grid_size = cute::ceil_div(numel, NumThreads);
|
||||
|
||||
dim3 grid(grid_size);
|
||||
dim3 block(NumThreads);
|
||||
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
|
||||
}
|
||||
|
||||
} //namespace cutlass
|
||||
84 examples/65_distributed_gemm/util/device_copy.h (new file)
@ -0,0 +1,84 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief generic device-to-device data movement kernel based for CuTe tensors.
|
||||
|
||||
NOTE: this kernel assigns one element copy to every thread, and is by no means
|
||||
an efficient way of copying tensors. It should only be used for convenience in
|
||||
reference checks.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream);
|
||||
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
__global__ void device_copy_kernel(TensorSource const tensor_source,
|
||||
TensorDestination tensor_destination) {
|
||||
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
using ElementSrc = typename TensorSource::value_type;
|
||||
using ElementDst = typename TensorDestination::value_type;
|
||||
NumericConverter<ElementDst, ElementSrc> converter;
|
||||
if (linear_idx < size(tensor_source)) {
|
||||
tensor_destination(linear_idx) = converter(tensor_source(linear_idx));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream) {
|
||||
|
||||
assert(tensor_source.size() == tensor_destination.size());
|
||||
|
||||
auto numel = tensor_source.size();
|
||||
static constexpr int NumThreads = 128;
|
||||
auto grid_size = cute::ceil_div(numel, NumThreads);
|
||||
|
||||
dim3 grid(grid_size);
|
||||
dim3 block(NumThreads);
|
||||
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
|
||||
}
|
||||
|
||||
} //namespace cutlass
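For orientation, a hedged usage sketch of the helper declared in this new header; the allocation size and float-to-float copy are illustrative only:

```cpp
#include <cuda_runtime.h>

#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "util/device_copy.h"   // the header introduced above

int main() {
  constexpr int kNumel = 1 << 20;

  float *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, kNumel * sizeof(float));
  cudaMalloc(&dst, kNumel * sizeof(float));

  // Wrap the raw device pointers as 1-D CuTe tensors over global memory.
  auto layout     = cute::make_layout(cute::make_shape(kNumel));
  auto tensor_src = cute::make_tensor(cute::make_gmem_ptr(src), layout);
  auto tensor_dst = cute::make_tensor(cute::make_gmem_ptr(dst), layout);

  // One thread per element; intended for reference checks, not for performance.
  cutlass::device_copy(tensor_src, tensor_dst, /*stream=*/nullptr);
  cudaDeviceSynchronize();

  cudaFree(src);
  cudaFree(dst);
  return 0;
}
```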
|
||||
@ -0,0 +1,712 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
    \brief Grouped-scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture

    This example demonstrates a grouped, scaled FP8 GEMM using the new CUTLASS 3.0
    APIs on the NVIDIA Hopper architecture. New features showcased in this example are as follows:

    1. The NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
    which are more efficient than the Ampere tensor core instructions.

    2. The NVIDIA Hopper architecture includes a new Tensor Memory Accelerator (TMA) unit to transfer large
    blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
    copies between thread blocks in a cluster.

    3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).

    4. This example shows all important fusions used by FP8 GEMM kernels, i.e., grouped scale factor along M for
    the A tensor, blocked scale factor along K for the A tensor, blocked scale factor for the B tensor, and the abs_max value of the D tensor.

    5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
    CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning them we can
    improve performance.

    Examples:

      $ ./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_deepgemm \
        --m=4096 --iterations=1000
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
// #include "cutlass/util/reference/host/tensor_copy.h"
|
||||
// #include "cutlass/util/reference/host/tensor_compare.h"
|
||||
// #include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
// #include "reference/host/gemm_with_groupwise_scaling.h"
|
||||
|
||||
#include "deep_gemm/include/deep_gemm/fp8_gemm.cuh"
|
||||
|
||||
// using namespace cute;
|
||||
using namespace deep_gemm;
|
||||
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
bool help;
|
||||
int iterations;
|
||||
int m, n, k, num_groups;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(4096),
|
||||
n(4096),
|
||||
k(4096),
|
||||
num_groups(4),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, defaults.m);
|
||||
cmd.get_cmd_line_argument("num_groups", num_groups, defaults.num_groups);
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "67_hopper_fp8_deepgemm\n\n"
|
||||
<< " Hopper FP8 DeepGEMM kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the m size\n"
|
||||
<< " --num_groups=<int> Sets the number of groups\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
constexpr int cdiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
|
||||
// #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/// Todo: add reference check
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
return true;
|
||||
}
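// Hedged sketch toward the TODO above: a naive host reference for the blockwise-scaled
// FP8 GEMM. Each A element is dequantized with its (row, k-block) scale and each B
// element with its (n-block, k-block) scale, accumulating in fp32. The tensor-view
// parameters are generic on purpose; wiring this up to TestGemm's HostTensors is left open.
template <typename ViewA, typename ViewB, typename ViewSFA, typename ViewSFB, typename ViewD>
void reference_gemm_blockwise_sketch(int m, int n, int k,
                                     ViewA const &A, ViewB const &B,
                                     ViewSFA const &SFA, ViewSFB const &SFB,
                                     ViewD &D) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        float a = float(A.at({i, p})) * float(SFA.at({i, p / 128}));
        float b = float(B.at({j, p})) * float(SFB.at({j / 128, p / 128}));
        acc += a * b;
      }
      D.at({i, j}) = cutlass::bfloat16_t(acc);
    }
  }
}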
|
||||
|
||||
|
||||
struct TestGemm {
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
using ElementScale = float;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::bfloat16_t;
|
||||
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_lhs;
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_rhs;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::ColumnMajor> Tensor_lhs_scale;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::RowMajor> Tensor_rhs_scale;
|
||||
cutlass::HostTensor<ElementOut, cutlass::layout::RowMajor> Tensor_out;
|
||||
|
||||
|
||||
/// Initialize operands to be used in the GEMM
|
||||
void initialize(
|
||||
const Options &options,
|
||||
uint64_t seed = 2025) {
|
||||
|
||||
Tensor_lhs.resize({options.m, options.k}); //[m, k]
|
||||
Tensor_rhs.resize({options.n, options.k}); //[n, k]
|
||||
Tensor_lhs_scale.resize({options.m, cdiv(options.k, 128)}); // [m, cdiv(k, 128)] column major
|
||||
Tensor_rhs_scale.resize({cdiv(options.n, 128), cdiv(options.k, 128)}); // [cdiv(n, 128), cdiv(k, 128)]
|
||||
Tensor_out.resize({options.m, options.n}); // [m, n]
|
||||
|
||||
initialize_tensor(Tensor_lhs.host_view(), cutlass::Distribution::Uniform, seed + 1);
|
||||
initialize_tensor(Tensor_rhs.host_view(), cutlass::Distribution::Uniform, seed + 2);
|
||||
initialize_scale_tensor(Tensor_lhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 3);
|
||||
initialize_scale_tensor(Tensor_rhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 4);
|
||||
|
||||
Tensor_lhs.sync_device();
|
||||
Tensor_rhs.sync_device();
|
||||
Tensor_lhs_scale.sync_device();
|
||||
Tensor_rhs_scale.sync_device();
|
||||
Tensor_out.sync_device();
|
||||
|
||||
}
|
||||
|
||||
void run(Options &options)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
|
||||
|
||||
initialize(options);
|
||||
|
||||
cudaStream_t stream{nullptr};
|
||||
constexpr auto N = 4096;
|
||||
constexpr auto K = 4096;
|
||||
constexpr auto BLOCK_M = 128;
|
||||
constexpr auto BLOCK_N = 128;
|
||||
constexpr auto kNumStages = 5;
|
||||
constexpr auto kNumTMAMulticast = 2;
|
||||
const int num_sms = 132; // for H100
|
||||
const int best_smem_size = 199376;
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmKernel = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
|
||||
int m = options.m;
|
||||
// DeepGEMM requires __nv_fp8_e4m3 input and __nv_bfloat16 output
|
||||
__nv_fp8_e4m3* lhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_lhs.device_data());
|
||||
__nv_fp8_e4m3* rhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_rhs.device_data());
|
||||
float* lhs_scales = Tensor_lhs_scale.device_data();
|
||||
float* rhs_scales = Tensor_rhs_scale.device_data();
|
||||
__nv_bfloat16* out = reinterpret_cast<__nv_bfloat16*>(Tensor_out.device_data());
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmKernel::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmKernel::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmKernel::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmKernel::make_2d_tma_d_desc(out, m);
|
||||
GemmKernel::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
std::cout << "run Gemm...\n";
|
||||
// TODO: reference check
|
||||
Result result;
|
||||
// result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
|
||||
|
||||
// std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
// initialize(options);
|
||||
GemmKernel::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, best_smem_size);
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): (128, 128, 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityM: 1 (ScaleMsPerTile: 128)" << std::endl;
|
||||
std::cout << " ScaleGranularityN: 128 (ScaleNsPerTile: 1)" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
};
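// Hedged sketch (editorial addition): the hardcoded num_sms = 132 and shared-memory
// budget above are H100-specific. They could instead be queried from the runtime via
// standard device attributes; the helper name below is illustrative only.
inline void query_launch_limits(int device_id, int &num_sms, int &max_dynamic_smem) {
  CUDA_CHECK(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device_id));
  CUDA_CHECK(cudaDeviceGetAttribute(&max_dynamic_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
}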
|
||||
|
||||
struct TestGroupedGemm_Contiguous {
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
using ElementScale = float;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::bfloat16_t;
|
||||
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_lhs;
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_rhs;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::ColumnMajor> Tensor_lhs_scale;
|
||||
cutlass::HostTensor<ElementScale, cutlass::layout::RowMajor> Tensor_rhs_scale;
|
||||
cutlass::HostTensor<ElementOut, cutlass::layout::RowMajor> Tensor_out;
|
||||
cutlass::HostTensor<int, cutlass::layout::RowMajor> Tensor_grouped_layout;
|
||||
|
||||
/// Initialize operands to be used in the GEMM
|
||||
void initialize(
|
||||
const Options &options,
|
||||
uint64_t seed = 2025) {
|
||||
|
||||
Tensor_lhs.resize({options.m, options.k}); //[m, k]
|
||||
Tensor_rhs.resize({options.num_groups * options.n, options.k}); //[num_groups, n, k]
|
||||
Tensor_lhs_scale.resize({options.m, cdiv(options.k, 128)}); // [m, cdiv(k, 128)] column major
|
||||
Tensor_rhs_scale.resize({options.num_groups * cdiv(options.n, 128), cdiv(options.k, 128)}); // [num_groups, cdiv(n, 128), cdiv(k, 128)]
|
||||
Tensor_out.resize({options.m, options.n}); // [m, n]
|
||||
Tensor_grouped_layout.resize({1,options.m}); // [num_groups,]
|
||||
|
||||
std::vector<int> group_start {0, options.m/4, 2*options.m/4, 3*options.m/4, options.m}; // sum(grouped_layout) = options.m
|
||||
for (int i = 0; i < options.m; ++i) {
|
||||
for(int j = 0; j < options.num_groups; ++j) {
|
||||
if(i >= group_start[j] && i < group_start[j+1]) {
|
||||
Tensor_grouped_layout.host_data()[i] = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}

    initialize_tensor(Tensor_lhs.host_view(), cutlass::Distribution::Uniform, seed + 1);
    initialize_tensor(Tensor_rhs.host_view(), cutlass::Distribution::Uniform, seed + 2);
    initialize_scale_tensor(Tensor_lhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 3);
    initialize_scale_tensor(Tensor_rhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 4);

    Tensor_lhs.sync_device();
    Tensor_rhs.sync_device();
    Tensor_lhs_scale.sync_device();
    Tensor_rhs_scale.sync_device();
    Tensor_out.sync_device();
    Tensor_grouped_layout.sync_device();
  }

  void run(Options &options)
  {
    cudaDeviceProp props;
    int current_device;
    CUDA_CHECK(cudaGetDevice(&current_device));
    CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));

    initialize(options);

    cudaStream_t stream{nullptr};
    constexpr auto N = 4096;
    constexpr auto K = 4096;
    constexpr auto BLOCK_M = 128;
    constexpr auto BLOCK_N = 128;
    constexpr auto num_groups = 4;
    constexpr auto kNumStages = 5;
    constexpr auto kNumTMAMulticast = 2;
    const int num_sms = 132; // for H100
    const int best_smem_size = 199376;
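    // Added note (not in the original example): N, K and num_groups are compile-time
    // parameters of the DeepGEMM-style kernel instantiated below, so they presumably
    // have to agree with options.n, options.k and options.num_groups for the run to be
    // meaningful. num_sms and best_smem_size are hard-coded for an H100; a sketch of
    // deriving them from the device queried above instead would be
    //   const int num_sms = props.multiProcessorCount;                        // 132 on H100 SXM
    //   const int best_smem_size = static_cast<int>(props.sharedMemPerBlockOptin);
    // (199376 bytes is a tuned value kept from the original example.)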

    // Make a templated GEMM
    using GemmKernel = Gemm<N, K, BLOCK_M, BLOCK_N, 128, num_groups, kNumStages, kNumTMAMulticast, GemmType::GroupedContiguous>;

    int m = options.m;
    // DeepGEMM requires __nv_fp8_e4m3 input and __nv_bfloat16 output
    __nv_fp8_e4m3* lhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_lhs.device_data());
    __nv_fp8_e4m3* rhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_rhs.device_data());
    float* lhs_scales = Tensor_lhs_scale.device_data();
    float* rhs_scales = Tensor_rhs_scale.device_data();
    __nv_bfloat16* out = reinterpret_cast<__nv_bfloat16*>(Tensor_out.device_data());
    int* grouped_layout = Tensor_grouped_layout.device_data();

    // Launch kernel
    auto tma_a_desc = GemmKernel::make_2d_tma_a_desc(lhs, m);
    auto tma_b_desc = GemmKernel::make_2d_tma_b_desc(rhs);
    auto tma_scales_a_desc = GemmKernel::make_2d_tma_scales_a_desc(lhs_scales, m);
    auto tma_d_desc = GemmKernel::make_2d_tma_d_desc(out, m);
    GemmKernel::run(out, rhs_scales, grouped_layout,
                    m,
                    tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
                    stream, num_sms, best_smem_size);

    CUDA_CHECK(cudaStreamSynchronize(stream));
    CUDA_CHECK(cudaDeviceSynchronize());

    std::cout << "run GroupedGemm Contiguous...\n";
    // TODO: reference check
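    // Hedged sketch (added, not part of the original example): a naive host-side
    // reference check for this contiguous grouped, blockwise-scaled GEMM could look
    // roughly like the following (O(m*n*k), so only practical for small sizes). It
    // assumes the tensor shapes allocated in initialize() and the 1x128 / 128x128
    // scaling granularity used above; the tolerance handling is left illustrative.
    //
    //   Tensor_out.sync_host();
    //   for (int i = 0; i < options.m; ++i) {
    //     int g = Tensor_grouped_layout.host_data()[i];
    //     for (int j = 0; j < options.n; ++j) {
    //       float acc = 0.f;
    //       for (int kk = 0; kk < options.k; ++kk) {
    //         float a  = float(Tensor_lhs.host_view().at({i, kk}));
    //         float b  = float(Tensor_rhs.host_view().at({g * options.n + j, kk}));
    //         float sa = Tensor_lhs_scale.host_view().at({i, kk / 128});
    //         float sb = Tensor_rhs_scale.host_view().at({g * cdiv(options.n, 128) + j / 128, kk / 128});
    //         acc += (a * sa) * (b * sb);
    //       }
    //       float got = float(Tensor_out.host_view().at({i, j}));
    //       // compare `got` against `acc` with a relative tolerance appropriate for FP8 inputs
    //     }
    //   }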
    Result result;
    // result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);

    // std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

    // if (!result.passed) {
    //   exit(-1);
    // }

    // Run profiling loop
    if (options.iterations > 0)
    {
      GpuTimer timer;
      timer.start();
      for (int iter = 0; iter < options.iterations; ++iter) {
        // initialize(options);
        GemmKernel::run(out, rhs_scales, grouped_layout,
                        m,
                        tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
                        stream, num_sms, best_smem_size);
      }
      timer.stop();

      // Compute average runtime and GFLOPs.
      float elapsed_ms = timer.elapsed_millis();
      result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
      result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

      std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
      std::cout << " Number of groups: " << options.num_groups << std::endl;
      std::cout << " Tile shape (M, N, K): (128, 128, 128)" << std::endl;
      std::cout << " ScaleGranularityM: 1 (ScaleMsPerTile: 128)" << std::endl;
      std::cout << " ScaleGranularityN: 128 (ScaleNsPerTile: 1)" << std::endl;
      std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
      std::cout << " GFLOPS: " << result.gflops << std::endl;
      fflush(stdout);
    }
  }
};

struct TestGroupedGemm_Masked {
  using Element = cutlass::float_e4m3_t;
  using ElementScale = float;
  using ElementAcc = float;
  using ElementOut = cutlass::bfloat16_t;

  cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_lhs;
  cutlass::HostTensor<Element, cutlass::layout::RowMajor> Tensor_rhs;
  cutlass::HostTensor<ElementScale, cutlass::layout::ColumnMajor> Tensor_lhs_scale;
  cutlass::HostTensor<ElementScale, cutlass::layout::RowMajor> Tensor_rhs_scale;
  cutlass::HostTensor<ElementOut, cutlass::layout::RowMajor> Tensor_out;
  cutlass::HostTensor<int, cutlass::layout::RowMajor> Tensor_masked_m;

  /// Initialize operands to be used in the GEMM
  void initialize(
    const Options &options,
    uint64_t seed = 2025) {

    int m_max = options.m;
    Tensor_lhs.resize({options.num_groups * m_max, options.k});                                 // [num_groups, m, k]
    Tensor_rhs.resize({options.num_groups * options.n, options.k});                             // [num_groups, n, k]
    Tensor_lhs_scale.resize({options.num_groups * m_max, cdiv(options.k, 128)});                // [num_groups, m, cdiv(k, 128)] column major
    Tensor_rhs_scale.resize({options.num_groups * cdiv(options.n, 128), cdiv(options.k, 128)}); // [num_groups, cdiv(n, 128), cdiv(k, 128)]
    Tensor_out.resize({options.num_groups * m_max, options.n});                                 // [num_groups, m, n]
    Tensor_masked_m.resize({1, options.num_groups});                                            // [num_groups,]

    // Number of valid rows per group; assumes options.num_groups == 4 and max(masked_m) <= options.m
    std::vector<int> masked_m {options.m/4, 2*options.m/4, 3*options.m/4, options.m};
    for (int i = 0; i < options.num_groups; ++i) {
      Tensor_masked_m.host_data()[i] = masked_m[i];
    }

    initialize_tensor(Tensor_lhs.host_view(), cutlass::Distribution::Uniform, seed + 1);
    initialize_tensor(Tensor_rhs.host_view(), cutlass::Distribution::Uniform, seed + 2);
    initialize_scale_tensor(Tensor_lhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 3);
    initialize_scale_tensor(Tensor_rhs_scale.host_view(), cutlass::Distribution::Uniform, seed + 4);

    Tensor_lhs.sync_device();
    Tensor_rhs.sync_device();
    Tensor_lhs_scale.sync_device();
    Tensor_rhs_scale.sync_device();
    Tensor_out.sync_device();
    Tensor_masked_m.sync_device();
  }

  void run(Options &options)
  {
    cudaDeviceProp props;
    int current_device;
    CUDA_CHECK(cudaGetDevice(&current_device));
    CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));

    initialize(options);

    cudaStream_t stream{nullptr};
    constexpr auto N = 4096;
    constexpr auto K = 4096;
    constexpr auto BLOCK_M = 128;
    constexpr auto BLOCK_N = 128;
    constexpr auto num_groups = 4;
    constexpr auto kNumStages = 5;
    constexpr auto kNumTMAMulticast = 2;
    const int num_sms = 132; // for H100
    const int best_smem_size = 199376;

    // Make a templated GEMM
    using GemmKernel = Gemm<N, K, BLOCK_M, BLOCK_N, 128, num_groups, kNumStages, kNumTMAMulticast, GemmType::GroupedMasked>;

    int m = options.m;
    // DeepGEMM requires __nv_fp8_e4m3 input and __nv_bfloat16 output
    __nv_fp8_e4m3* lhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_lhs.device_data());
    __nv_fp8_e4m3* rhs = reinterpret_cast<__nv_fp8_e4m3*>(Tensor_rhs.device_data());
    float* lhs_scales = Tensor_lhs_scale.device_data();
    float* rhs_scales = Tensor_rhs_scale.device_data();
    __nv_bfloat16* out = reinterpret_cast<__nv_bfloat16*>(Tensor_out.device_data());
    int* masked_m = Tensor_masked_m.device_data();

    // Launch kernel
    auto tma_a_desc = GemmKernel::make_2d_tma_a_desc(lhs, m);
    auto tma_b_desc = GemmKernel::make_2d_tma_b_desc(rhs);
    auto tma_scales_a_desc = GemmKernel::make_2d_tma_scales_a_desc(lhs_scales, m);
    auto tma_d_desc = GemmKernel::make_2d_tma_d_desc(out, m);
    GemmKernel::run(out, rhs_scales, masked_m,
                    m,
                    tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
                    stream, num_sms, best_smem_size);

    CUDA_CHECK(cudaStreamSynchronize(stream));
    CUDA_CHECK(cudaDeviceSynchronize());

    std::cout << "run GroupedGemm Masked...\n";
    // TODO: reference check
    Result result;
    // result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);

    // std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

    // if (!result.passed) {
    //   exit(-1);
    // }

    // Run profiling loop
    if (options.iterations > 0)
    {
      GpuTimer timer;
      timer.start();
      for (int iter = 0; iter < options.iterations; ++iter) {
        // initialize(options);
        GemmKernel::run(out, rhs_scales, masked_m,
                        m,
                        tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
                        stream, num_sms, best_smem_size);
      }
      timer.stop();

      // Compute average runtime and GFLOPs.
      float elapsed_ms = timer.elapsed_millis();
      result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
      result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

      std::cout << " Problem Size: M (per group, max " << options.m << ") x " << options.n << " x " << options.k << std::endl;
      std::cout << " Number of groups: " << options.num_groups << std::endl;
      std::cout << " Number of masked rows: ";
      for (int i = 0; i < options.num_groups; ++i) {
        std::cout << Tensor_masked_m.host_data()[i] << " ";
      }
      std::cout << std::endl;
      std::cout << " Tile shape (M, N, K): (128, 128, 128)" << std::endl;
      std::cout << " ScaleGranularityM: 1 (ScaleMsPerTile: 128)" << std::endl;
      std::cout << " ScaleGranularityN: 128 (ScaleNsPerTile: 1)" << std::endl;
      std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
      std::cout << " GFLOPS: " << result.gflops << std::endl;
      fflush(stdout);
    }
  }
};

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  // CUTLASS must be compiled with the CUDA 12.0 Toolkit (or newer) to run this example
  // and the device must have compute capability 90.
  if (__CUDACC_VER_MAJOR__ < 12) {
    std::cerr << "This example requires CUDA 12 or newer.\n";
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (props.major != 9) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture "
      << "(compute capability 90).\n";
    return 0;
  }

  //
  // Parse options
  //
#if defined (CUTLASS_ARCH_MMA_SM90_SUPPORTED)

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  TestGemm testgemm{};
  testgemm.run(options);

  TestGroupedGemm_Contiguous testgroupedgemm_contiguous{};
  testgroupedgemm_contiguous.run(options);

  TestGroupedGemm_Masked testgroupedgemm_masked{};
  testgroupedgemm_masked.run(options);

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -75,11 +75,11 @@
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
#include "reference/host/gemm_with_blockwise_scaling.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
@ -123,13 +123,7 @@ using ArchTag = cutlass::arch::Sm90; // T
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(TileShape{}));
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
@ -149,8 +143,8 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
|
||||
|
||||
using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
@ -196,22 +190,20 @@ StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
StrideAux stride_aux;
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
uint64_t seed;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
uint32_t mma_promotion_interval;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutA> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutB> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
|
||||
@ -259,116 +251,117 @@ struct Result
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, bits_input);
|
||||
return true;
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
// Find Block Scaling tensor shapes based on problem shape and TileShape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_aux = stride_D;
|
||||
|
||||
// Layout SFA and SFB represent logically broadcast data in CuTe.
// E.g., if layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGranularityK, K / ScaleGranularityK))
// and strides ((0, 1), (0, M / ScaleGranularityM)), then each collection of ScaleGranularityM x ScaleGranularityK
// indices in the tensor maps to the same offset.
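// Added illustration (not in the original example): such a broadcast layout could be
// spelled out directly in CuTe as, e.g. (the variable name is illustrative only),
//   auto sfa_layout = cute::make_layout(
//       cute::make_shape (cute::make_shape(ScaleGranularityM, M / ScaleGranularityM),
//                         cute::make_shape(ScaleGranularityK, K / ScaleGranularityK)),
//       cute::make_stride(cute::make_stride(0, 1),
//                         cute::make_stride(0, M / ScaleGranularityM)));
// so every coordinate inside one ScaleGranularityM x ScaleGranularityK scaling block
// resolves to the same scale-factor element.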
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
auto blockscale_a_coord = cutlass::make_Coord(blockscale_m * options.l, blockscale_k);
|
||||
auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
blockscale_tensor_A.resize(blockscale_a_coord);
|
||||
@ -405,10 +398,6 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
blockscale_tensor_A.sync_device();
|
||||
blockscale_tensor_B.sync_device();
|
||||
|
||||
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
|
||||
// Else kernel will fail can_implement() check
|
||||
// Deprecation Notice : We plan to remove this params member in an upcoming release
|
||||
// Users can safely delete this line from their code, since the default is already 4
|
||||
mma_promotion_interval = 4;
|
||||
|
||||
if (options.save_aux) {
|
||||
@ -445,18 +434,14 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
}
|
||||
|
||||
if (IsAuxFp8 && options.save_aux && options.save_amax) {
|
||||
abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
abs_max_aux.sync_device();
|
||||
reference_abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -472,9 +457,7 @@ typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &op
|
||||
stride_B,
|
||||
mma_promotion_interval,
|
||||
blockscale_tensor_A.device_data(),
|
||||
layout_SFA,
|
||||
blockscale_tensor_B.device_data(),
|
||||
layout_SFB
|
||||
blockscale_tensor_B.device_data()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
@ -528,6 +511,13 @@ bool verify(const Options<RasterOrderOptions> &options) {
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Block scaling tensor shapes based on the CTA tile (TileShape) and the GEMM problem shape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
@ -560,18 +550,28 @@ bool verify(const Options<RasterOrderOptions> &options) {
|
||||
)
|
||||
);
|
||||
|
||||
auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA);
|
||||
auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB);
|
||||
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_m, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k)
|
||||
)
|
||||
);
|
||||
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_n, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
|
||||
)
|
||||
);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator,
|
||||
decltype(A), decltype(B),
|
||||
decltype(blockscale_A), decltype(blockscale_B),
|
||||
TileShape> mainloop_params{
|
||||
A, B, // Operand Tensors
|
||||
blockscale_A, blockscale_B // Blockwise scaling Tensors
|
||||
};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
@ -604,40 +604,29 @@ bool verify(const Options<RasterOrderOptions> &options) {
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
bool passed = true;
|
||||
tensor_D.sync_host();
|
||||
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
|
||||
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
#if 0
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
#endif
|
||||
if (false) {
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.sync_host();
|
||||
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
|
||||
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
|
||||
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
|
||||
}
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.sync_host();
|
||||
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
|
||||
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
|
||||
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
|
||||
if (IsAuxFp8 && options.save_amax) {
|
||||
abs_max_aux.sync_host();
|
||||
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
|
||||
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
|
||||
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
|
||||
}
|
||||
}
|
||||
|
||||
@ -673,22 +662,20 @@ int run(Options<RasterOrderOptions> &options)
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
if (options.verify) {
|
||||
result.passed = verify(options);
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
else {
|
||||
result.passed = true;
|
||||
}
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
if (iter == options.warmup)
|
||||
timer.start();
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
@ -713,7 +700,7 @@ int run(Options<RasterOrderOptions> &options)
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return result.passed;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -759,9 +746,7 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
bool passed = run<Gemm>(options);
|
||||
if (!passed)
|
||||
return -1;
|
||||
run<Gemm>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -75,11 +75,11 @@
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
#include "reference/host/gemm_with_groupwise_scaling.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
@ -120,30 +120,55 @@ using ElementAccumulator = float; // E
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 128;
|
||||
constexpr int ScaleGranularityK = 128;
|
||||
// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor
|
||||
// Given TileShape = Shape<_128,_128,_128>:
|
||||
// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor)
|
||||
// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling
|
||||
// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling
|
||||
// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling
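// Added note (not in the original): with TileShape = (128, 128, 128) these choices give, e.g.,
//   ScaleGranularityM ==   1, ScaleGranularityN == 128  ->  ScaleMsPerTile = 128, ScaleNsPerTile = 1
//   ScaleGranularityM == 128, ScaleGranularityN == 128  ->  ScaleMsPerTile =   1, ScaleNsPerTile = 1
// since ScaleMsPerTile = size<0>(TileShape) / ScaleGranularityM and ScaleNsPerTile = size<1>(TileShape) / ScaleGranularityN,
// as computed in the GroupScaleConfig struct below.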
|
||||
template <int ScaleGranularityM_, int ScaleGranularityN_>
|
||||
struct GroupScaleConfig {
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_;
|
||||
static constexpr int ScaleGranularityN = ScaleGranularityN_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
|
||||
"FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile,
|
||||
"FP8 scaling granularity must evenly divide tile shape along N.");
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;
|
||||
};
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>;
|
||||
using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>;
|
||||
using GroupScale2D1DConfig = GroupScaleConfig<size<0>(TileShape_{}), 1>;
|
||||
using GroupScale2D2DConfig = GroupScaleConfig<size<0>(TileShape_{}), size<1>(TileShape_{})>;
|
||||
|
||||
template <typename ScheduleConfig>
|
||||
struct GroupScaleGemm {
|
||||
using ArchTag = typename ScheduleConfig::ArchTag;
|
||||
using OperatorClass = typename ScheduleConfig::OperatorClass;
|
||||
using TileShape = typename ScheduleConfig::TileShape;
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape;
|
||||
using KernelSchedule = typename ScheduleConfig::KernelSchedule;
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
|
||||
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
|
||||
using FusionOperation = typename ScheduleConfig::FusionOperation;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
@ -154,10 +179,10 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
@ -166,26 +191,38 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
cutlass::gemm::StreamKScheduler
|
||||
using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloopWithGroupWiseScaling,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloopWithGroupWiseScaling,
|
||||
CollectiveEpilogue,
|
||||
cutlass::gemm::StreamKScheduler
|
||||
>;
|
||||
|
||||
using GemmDefault = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelDefault>;
|
||||
using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelStreamK>;
|
||||
};
|
||||
|
||||
using GroupScale1D1DGemm = GroupScaleGemm<GroupScale1D1DConfig>;
|
||||
using GroupScale1D2DGemm = GroupScaleGemm<GroupScale1D2DConfig>;
|
||||
using GroupScale2D1DGemm = GroupScaleGemm<GroupScale2D1DConfig>;
|
||||
using GroupScale2D2DGemm = GroupScaleGemm<GroupScale2D2DConfig>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using EpilogueOutputOp = typename GroupScale1D1DGemm::GemmDefault::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideA = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideA;
|
||||
using StrideB = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideB;
|
||||
using StrideC = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideC;
|
||||
using StrideD = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideD;
|
||||
using StrideAux = StrideD;
|
||||
|
||||
constexpr bool IsDFp8 =
|
||||
@ -205,23 +242,20 @@ StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
StrideAux stride_aux;
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
uint64_t seed;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
uint32_t mma_promotion_interval;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutA> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutB> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
|
||||
@ -269,114 +303,120 @@ struct Result
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, bits_input);
|
||||
return true;
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
template <typename GroupScaleConfig>
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
assert(options.m % ScaleGranularityM == 0);
|
||||
assert(options.n % ScaleGranularityN == 0);
|
||||
using TileShape = typename GroupScaleConfig::TileShape;
|
||||
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
|
||||
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
|
||||
|
||||
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
|
||||
auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in the scale tensor of B to prevent illegal memory access.
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_aux = stride_D;
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto groupscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto groupscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k);
|
||||
auto groupscale_b_coord = cutlass::make_Coord(groupscale_n * options.l, blockscale_k);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
@ -413,10 +453,6 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
blockscale_tensor_A.sync_device();
|
||||
blockscale_tensor_B.sync_device();
|
||||
|
||||
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
|
||||
// Else kernel will fail can_implement() check
|
||||
// Deprecation Notice : We plan to remove this params member in an upcoming release
|
||||
// Users can safely delete this line from their code, since the default is already 4
|
||||
mma_promotion_interval = 4;
|
||||
|
||||
if (options.save_aux) {
|
||||
@ -481,9 +517,7 @@ GemmArguments args_from_options(const Options<RasterOrderOptions> &options)
|
||||
stride_B,
|
||||
mma_promotion_interval,
|
||||
blockscale_tensor_A.device_data(),
|
||||
layout_SFA,
|
||||
blockscale_tensor_B.device_data(),
|
||||
layout_SFB
|
||||
blockscale_tensor_B.device_data()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
@ -533,11 +567,18 @@ GemmArguments args_from_options(const Options<RasterOrderOptions> &options)
|
||||
}
|
||||
|
||||
/// Don't know why the compiler does not like verify() being templated...
|
||||
bool verify(const Options<RasterOrderOptions> &options) {
|
||||
bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile, const int ScaleNsPerTile) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Group scaling tensor shapes based on `ScaleGranularityM`, the CTA tile (TileShape) and the GEMM problem shape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
  auto blockscale_k = cute::get<2>(blockscale_shape);

  // Create instantiation for device reference gemm kernel
  auto A = cute::make_tensor(tensor_A.host_data(),
    cute::make_layout(
@@ -570,18 +611,28 @@ bool verify(const Options<RasterOrderOptions> &options) {
      )
  );

  auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA);
  auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB);
  auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
    cute::make_layout(
      cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
      cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
    )
  );
  auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
    cute::make_layout(
      cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l),
      cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile)
    )
  );

  using unused_t = decltype(D);

  cutlass::reference::host::GettBlockScalingMainloopParams<
    ElementAccumulator,
    decltype(A),
    decltype(SFA),
    decltype(B),
    decltype(SFB)
  > mainloop_params{A, SFA, B, SFB};
  cutlass::reference::host::GettMainloopParams<ElementAccumulator,
    decltype(A), decltype(B),
    decltype(blockscale_A), decltype(blockscale_B),
    TileShape_> mainloop_params{
      A, B,                       // Operand Tensors
      blockscale_A, blockscale_B  // Groupwise scaling Tensors
    };

  cutlass::reference::host::GettEpilogueParams<
    ElementScalar,
@@ -614,40 +665,29 @@ bool verify(const Options<RasterOrderOptions> &options) {
  cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);

  // compare_reference
  bool passed = true;
  tensor_D.sync_host();
  passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
  double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
  double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
  double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
  std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
  bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());

#if 0
  std::cout << "tensor_ref_D.host_view() {" << std::endl
            << tensor_ref_D.host_view() << std::endl
            << "}" << std::endl;
  std::cout << "tensor_D.host_view() {" << std::endl
            << tensor_D.host_view() << std::endl
            << "}" << std::endl;
#endif
  if (false) {
    std::cout << "tensor_ref_D.host_view() {" << std::endl
              << tensor_ref_D.host_view() << std::endl
              << "}" << std::endl;
    std::cout << "tensor_D.host_view() {" << std::endl
              << tensor_D.host_view() << std::endl
              << "}" << std::endl;
  }

  if (IsDFp8 && options.save_amax) {
    abs_max_D.sync_host();
    std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
    passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
    passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
  }

  if (options.save_aux) {
    tensor_aux.sync_host();
    passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
    mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
    mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
    max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
    std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
    passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
    if (IsAuxFp8 && options.save_amax) {
      abs_max_aux.sync_host();
      std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
      passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
      passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
    }
  }

@@ -655,34 +695,16 @@ bool verify(const Options<RasterOrderOptions> &options) {
}

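// Editor's note (illustrative, not part of the original example): the reference scale
// tensors above are sized from the problem shape and the scale granularity. Assuming
// TileShape_ = (TileM, TileN, TileK), the relationships used by this example are roughly:
//   blockscale_m   = ceil_div(options.m, TileM)   // scale tiles along M
//   blockscale_n   = ceil_div(options.n, TileN)   // scale tiles along N
//   blockscale_k   = ceil_div(options.k, TileK)   // scale blocks along K
//   ScaleMsPerTile = TileM / ScaleGranularityM    // scales per tile along M
//   ScaleNsPerTile = TileN / ScaleGranularityN    // scales per tile along N
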
/// Execute a given example GEMM computation
int run(Options<RasterOrderOptions> &options) {
template <typename GroupScaleConfig, typename Gemm>
int run(Options<RasterOrderOptions> &options)
{
  using TileShape = typename GroupScaleConfig::TileShape;
  const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM;
  const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN;
  const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
  const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;

  bool skip = false;
  std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
  std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
  std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
  std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;


  if (options.m < ScaleGranularityM) {
    std::cout << " Skipping (m size: " << options.m << " less than ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
    skip = true;
  }

  if (options.n < ScaleGranularityN) {
    std::cout << " Skipping (n size: " << options.n << " less than ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
    skip = true;
  }

  if (options.k < size<2>(TileShape{})) {
    std::cout << " Skipping (k size: " << options.k << " less than TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
    skip = true;
  }

  if (!skip) std::cout << " Running... " << std::endl;
  else return -1;

  initialize(options);
  initialize<GroupScaleConfig>(options);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm;
@@ -707,22 +729,20 @@ int run(Options<RasterOrderOptions> &options) {

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  Result result;
  if (options.verify) {
    result.passed = verify(options);
    result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);

    std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
  }
  else {
    result.passed = true;
  }
  std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

  // if (!result.passed) {
  //   exit(-1);
  // }

  // Run profiling loop
  if (options.iterations > 0)
  {
    GpuTimer timer;
    for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
      if (iter == options.warmup)
        timer.start();
    timer.start();
    for (int iter = 0; iter < options.iterations; ++iter) {
      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
      CUTLASS_CHECK(gemm.run());
    }
@@ -742,13 +762,17 @@ int run(Options<RasterOrderOptions> &options) {
      raster = "Along M";
    }

    std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
    std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
    std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
    std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
    std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
    std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
    std::cout << " GFLOPS: " << result.gflops << std::endl;
    fflush(stdout);
  }

  return result.passed;
  return 0;
}

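The profiling change above replaces the single combined loop with an explicit untimed warmup phase followed by a timed loop. A minimal, self-contained sketch of that pattern is shown below; it assumes the `GpuTimer` helper from the examples' common code exposes `start()`, `stop()`, and `elapsed_millis()` (names taken on trust here), and any callable standing in for `gemm.run()` can be timed this way.

```cpp
// Warmup-then-measure timing sketch (illustrative, not the example's exact code).
template <class Fn>
float time_kernel_ms(Fn&& run_once, int warmup, int iterations) {
  for (int i = 0; i < warmup; ++i) {
    run_once();                      // untimed warmup iterations
  }
  GpuTimer timer;
  timer.start();
  for (int i = 0; i < iterations; ++i) {
    run_once();                      // timed iterations
  }
  timer.stop();
  return timer.elapsed_millis() / float(iterations);
}
```
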
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@@ -794,10 +818,27 @@ int main(int argc, char const **args) {
  //

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  bool passed = true;
  passed = run(options);
  if (!passed)
    return -1;
  std::cout << "Basic split-K GEMM kernel" << std::endl;
  run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmDefault>(options);
  std::cout << std::endl;
  run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmDefault>(options);
  std::cout << std::endl;
  run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmDefault>(options);
  std::cout << std::endl;
  run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmDefault>(options);
  std::cout << std::endl;

  std::cout << std::endl;

  std::cout << "StreamK GEMM kernel" << std::endl;
  run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmStreamK>(options);
  std::cout << std::endl;
  run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmStreamK>(options);
  std::cout << std::endl;
  run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmStreamK>(options);
  std::cout << std::endl;
  run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmStreamK>(options);
  std::cout << std::endl;
#endif

  return 0;

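The four `GroupScale*Config` types exercised above are defined earlier in this example. For orientation only, a hypothetical config of the same shape might look like the sketch below; the names and values here are illustrative and are not the example's actual definitions.

```cpp
#include <cute/tensor.hpp>  // cute::Shape, cute::_128

// Illustrative only: a "2D x 1D"-style config with per-64-row scales for A and
// per-tile scales for B. The real GroupScale*Config definitions live earlier in
// 67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu.
struct ExampleGroupScaleConfig {
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  static constexpr int ScaleGranularityM = 64;    // one A scale per 64 rows of the tile
  static constexpr int ScaleGranularityN = 128;   // one B scale per full tile width
  static constexpr int ScaleMsPerTile    = 128 / ScaleGranularityM;  // size<0>(TileShape{}) / ScaleGranularityM == 2
  static constexpr int ScaleNsPerTile    = 128 / ScaleGranularityN;  // size<1>(TileShape{}) / ScaleGranularityN == 1
};
static_assert(ExampleGroupScaleConfig::ScaleMsPerTile == 2, "two A-scale rows per tile");
static_assert(ExampleGroupScaleConfig::ScaleNsPerTile == 1, "one B-scale column per tile");
```
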
@@ -35,3 +35,8 @@ cutlass_example_add_executable(
  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
)

cutlass_example_add_executable(
  67_hopper_fp8_deepgemm
  67_hopper_fp8_deepgemm.cu
)

@@ -0,0 +1,13 @@
import torch

from . import jit
from .jit_kernels import (
    gemm_fp8_fp8_bf16_nt,
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
    m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    cell_div,
    set_num_sms, get_num_sms,
    get_col_major_tma_aligned_tensor,
    get_m_alignment_for_contiguous_layout
)
from .utils import bench, bench_kineto, calc_diff
@@ -0,0 +1,444 @@
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#pragma once

#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>

#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>

#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"

namespace deep_gemm {

enum class Layout {
    RowMajor,
    ColMajor
};

template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
    return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
}

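// Editor's illustration (added here, not part of the upstream file): with the defaults used
// by Gemm::run() below (kNumTMAThreads = 128, kNumMathThreadsPerGroup = 128), the formula
// above yields one math warp-group plus the TMA warps for BLOCK_M == 64, and two math
// warp-groups plus the TMA warps otherwise.
static_assert(get_num_threads_per_sm<128, 128>(64) == 256, "1 * 128 math threads + 128 TMA threads");
static_assert(get_num_threads_per_sm<128, 128>(128) == 384, "2 * 128 math threads + 128 TMA threads");
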
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumGroups, uint32_t kNumStages,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
          uint32_t kNumTMAMulticast,
          GemmType kGemmType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                uint32_t shape_m,
                const __grid_constant__ CUtensorMap tensor_map_a,
                const __grid_constant__ CUtensorMap tensor_map_b,
                const __grid_constant__ CUtensorMap tensor_map_scales_a,
                const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    // Scaling checks
    DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
    DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too many B scales in a single block");

    // Types
    using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // Shared memory
    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
    static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
    static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);

    // Configs
    constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
    constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
    constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
    constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = get_lane_id();

    // Prefetch TMA descriptors at very beginning
    if (threadIdx.x == kNumMathThreads) {
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
        cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

    // Data on shared memory
    auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
    __nv_fp8_e4m3* smem_a[kNumStages];
    __nv_fp8_e4m3* smem_b[kNumStages];
    float* smem_scales_a[kNumStages];
    float* smem_scales_b;

    // TMA Barrier for both divisible and non-divisible cases
    Barrier* full_barriers[kNumStages];
    Barrier* empty_barriers[kNumStages];

    // Fill shared memory pointers
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
        smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
        smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
    }
    smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));

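    // Editor's note (illustrative): each pipeline stage therefore occupies
    // SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE bytes
    // (a BLOCK_M x BLOCK_K FP8 tile of A, a BLOCK_N x BLOCK_K FP8 tile of B, and BLOCK_M
    // floats of A scales), laid out back-to-back after the BLOCK_M x BLOCK_N bf16 output tile.
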
    // Fill barriers
    DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
    DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers");
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2));
    #pragma unroll
    for (int i = 0; i < kNumStages; ++ i) {
        full_barriers[i] = barrier_start_ptr + i;
        empty_barriers[i] = barrier_start_ptr + kNumStages + i;
    }

    // Initialize barriers
    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
    if (threadIdx.x == kNumMathThreads) {
        #pragma unroll
        for (int i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_view_async_shared();
        (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
    }

    // Synchronize all threads to make barrier visible in normal memory model
    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();

    // For pipeline unrolling
    struct DivisibleK {};
    struct NotDivisibleK {};
    auto launch_k_iterations = [](const auto& func) {
        if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
            for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
                func(k_iter, DivisibleK{});
        } else {
            for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
                func(k_iter, DivisibleK{});
            func(kNumIterations - 1, NotDivisibleK{});
        }
    };

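    // Editor's note (illustrative): the dynamic shared memory the host must provide for this
    // layout is roughly
    //   SMEM_D_SIZE
    //   + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)
    //   + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float)   // B scales
    //   + kNumStages * 2 * sizeof(Barrier),                                    // full + empty barriers
    // which the smem_size argument passed to Gemm::run() further below is assumed to cover
    // (the actual value is computed on the host side, outside this file).
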
    // Register reconfigurations
    constexpr int kNumTMARegisters = 40;
    constexpr int kNumMathRegisters = 232;

    // Block scheduler
    uint32_t m_block_idx, n_block_idx;
    auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);

    if (threadIdx.x >= kNumMathThreads) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        if (threadIdx.x == kNumMathThreads) {
            // Persistently schedule over blocks
            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
                launch_k_iterations([&](int k_iter, auto type) {
                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                    #pragma unroll
                    for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                        // Wait consumer release
                        empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);

                        // Issue TMA A with broadcasting
                        auto& full_barrier = *full_barriers[s];
                        int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                        tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
                        tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                                   smem_scales_a[s], m_block_idx * BLOCK_M,
                                                   scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));

                        // Issue TMA B without broadcasting
                        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                 smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
                        full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
                    }

                    // Wait unaligned cases
                    #pragma unroll
                    for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                        empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                        full_barriers[s]->arrive();
                    }
                });
            }

            // To safely deconstruct distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                #pragma unroll
                for (uint32_t s = 0; s < kNumStages; ++ s)
                    empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
        const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;

        // Persistently schedule over blocks
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            // Decide the number of scales B to load
            DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
            uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
            if constexpr (not kMustUseUniformedScaleB) {
                num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
                num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
            }
            uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);

            // Load B scales with math warp-groups
            // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
            if (threadIdx.x >= 32) {
                auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
                auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
                #pragma unroll
                for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
                    st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
            }
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Accumulation for WGMMA or CUDA promotion
            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};

            // Empty barrier arrival
            auto empty_barrier_arrive = [&](int s) {
                if constexpr (kNumTMAMulticast == 1) {
                    lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                } else {
                    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
                }
            };

            // Launch MMAs
            launch_k_iterations([&](int k_iter, auto type) {
                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                #pragma unroll
                for (int s = 0; s < kNumInnerStages; ++ s) {
                    // Read B scales
                    float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
                    // NOTES: even if some blocks do not need the second row of B scales, we still load one to stay aligned with other blocks
                    if constexpr (not kMustUseUniformedScaleB)
                        scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);

                    // Wait TMA arrivals
                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);

                    // Read A scales
                    // NOTES: all shared memory reads must happen before `warpgroup_arrive` to avoid the next scheduled block polluting the results
                    auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);

                    // Commit WGMMA instructions
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                        warpgroup_fence_operand(accum[i]);
                    warpgroup_arrive();
                    #pragma unroll
                    for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                        auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
                        auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                        WGMMA::wgmma(desc_a, desc_b, accum, k);
                    }
                    warpgroup_commit_batch();
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum; ++ i)
                        warpgroup_fence_operand(accum[i]);
                    warpgroup_wait<0>();

                    // Notify barrier arrival
                    empty_barrier_arrive(s);

                    // Promote with scales
                    float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
                    float scale_0_1, scale_1_1;
                    if constexpr (not kMustUseUniformedScaleB)
                        scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
                    #pragma unroll
                    for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                        bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                        final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
                        final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
                        final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
                        final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
                    }
                }

                // Wait unaligned cases
                #pragma unroll
                for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                    empty_barrier_arrive(s);
                }
            });

            // Write back to shared memory using STSM
            DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
            #pragma unroll
            for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
                SM90_U32x4_STSM_N<nv_bfloat162>::copy(
                    __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
                    __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
                );
            }
            if constexpr (WGMMA::kNumAccum % 8 != 0) {
                SM90_U32x2_STSM_N<nv_bfloat162>::copy(
                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
                    __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
                    smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
                );
            }
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier(kNumMathThreads).sync();

            // Use TMA store to write back to global memory
            if (threadIdx.x == 0) {
                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
                cute::tma_store_arrive();
                cute::tma_store_wait<0>();
            }
            __syncwarp();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
#endif
}

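// Editor's note (illustrative): the handshake above is a phase-bit protocol. For stage s in
// outer iteration k_iter, consumers wait on full_barriers[s] with phase
// ((current_iter * kNumIterations + k_iter) & 1), while the producer waits on empty_barriers[s]
// with the opposite phase ((current_iter * kNumIterations + k_iter + 1) & 1), so each stage
// buffer alternates between "filled by TMA" and "released by the math warp-groups" without
// any additional synchronization.
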
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumGroups, uint32_t kNumStages,
          uint32_t kNumTMAMulticast,
          GemmType kGemmType>
class Gemm {
private:
    using Barrier = cuda::barrier<cuda::thread_scope_block>;

public:
    Gemm() = default;

    static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                    uint32_t shape_m,
                    const CUtensorMap& tma_a_desc,
                    const CUtensorMap& tma_b_desc,
                    const CUtensorMap& tma_scales_a_desc,
                    const CUtensorMap& tma_d_desc,
                    cudaStream_t stream,
                    int num_sms, uint32_t smem_size) {
        // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
        constexpr uint32_t kNumTMAThreads = 128;
        constexpr uint32_t kNumMathThreadsPerGroup = 128;
        auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
                                      kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
                                      kNumTMAMulticast, kGemmType>;
        DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);

        // Cluster launch
        cudaLaunchConfig_t config;
        config.gridDim = num_sms;
        config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
        config.dynamicSmemBytes = smem_size;
        config.stream = stream;

        // Clusters for TMA multicast
        // NOTES: `>= 4` cluster size will cause performance degradation
        cudaLaunchAttribute attr;
        attr.id = cudaLaunchAttributeClusterDimension;
        attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
        config.attrs = &attr;
        config.numAttrs = 1;

        // Launch
        auto status = cudaLaunchKernelEx(&config, kernel,
                                         gmem_d, scales_b, grouped_layout,
                                         shape_m,
                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
        DG_HOST_ASSERT(status == cudaSuccess);
    }

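    // Editor's sketch of a hypothetical call site (not part of the upstream file; the shape
    // and configuration values below are placeholders, not recommended settings):
    //
    //   using GemmKernel = deep_gemm::Gemm</*SHAPE_N=*/4096, /*SHAPE_K=*/7168,
    //                                      /*BLOCK_M=*/128, /*BLOCK_N=*/128, /*BLOCK_K=*/128,
    //                                      /*kNumGroups=*/1, /*kNumStages=*/5,
    //                                      /*kNumTMAMulticast=*/1, GemmType::Normal>;
    //   auto tma_a  = GemmKernel::make_2d_tma_a_desc(a_fp8_ptr, shape_m);
    //   auto tma_b  = GemmKernel::make_2d_tma_b_desc(b_fp8_ptr);
    //   auto tma_sa = GemmKernel::make_2d_tma_scales_a_desc(scales_a_ptr, shape_m);
    //   auto tma_d  = GemmKernel::make_2d_tma_d_desc(d_bf16_ptr, shape_m);
    //   GemmKernel::run(d_bf16_ptr, scales_b_ptr, /*grouped_layout=*/nullptr, shape_m,
    //                   tma_a, tma_b, tma_sa, tma_d, stream, num_sms, smem_size);
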
    template <typename T>
    static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_b_desc(T* global_address) {
        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
        return make_2d_tma_desc(global_address, Layout::RowMajor,
                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

    template <typename T>
    static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
        // Make TMA aligned to 16 bytes
        constexpr uint32_t kAlignment = 16 / sizeof(T);
        shape_m = cell_div(shape_m, kAlignment) * kAlignment;

        return make_2d_tma_desc(global_address, Layout::ColMajor,
                                shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
                                CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
    }

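    // Editor's note (worked example, illustrative): for T = float, kAlignment = 16 / 4 = 4,
    // so a row count such as shape_m = 1022 is padded to cell_div(1022, 4) * 4 = 1024,
    // keeping every row of the column-major A-scales tensor 16-byte aligned for TMA.
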
    template <typename T>
    static CUtensorMap make_2d_tma_desc(
            T* global_address, Layout layout,
            uint32_t gmem_rows, uint32_t gmem_cols,
            uint32_t smem_rows, uint32_t smem_cols,
            CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
        if (layout == Layout::RowMajor) {
            uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
            uint32_t smem_dim[2] = {smem_cols, smem_rows};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
        } else {
            uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
            uint32_t smem_dim[2] = {smem_rows, smem_cols};
            return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
        }
    }
};

}; // namespace deep_gemm

#pragma clang diagnostic pop
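The file that follows (`mma_utils.cuh`) supplies the per-`BLOCK_N` WGMMA atoms that `FP8MMASelector<BLOCK_N>` resolves to in the kernel above. The selector itself is not shown in this excerpt; the sketch below illustrates how such a BLOCK_N-to-atom mapping could look, using the atom names defined in the file (illustrative only, not the library's actual selector).

```cpp
// Hypothetical selector sketch; the real FP8MMASelector lives in mma_utils.cuh.
template <uint32_t BLOCK_N>
struct ExampleFP8MMASelector;

template <> struct ExampleFP8MMASelector<16>  { using type = SM90_64x16x32_F32E4M3E4M3_SS; };
template <> struct ExampleFP8MMASelector<32>  { using type = SM90_64x32x32_F32E4M3E4M3_SS; };
template <> struct ExampleFP8MMASelector<64>  { using type = SM90_64x64x32_F32E4M3E4M3_SS; };
template <> struct ExampleFP8MMASelector<128> { using type = SM90_64x128x32_F32E4M3E4M3_SS; };

// Each atom exposes its tile as M/N/K plus kNumAccum = M * N / 128, the number of
// accumulator registers each of the 128 threads in a warp-group owns.
static_assert(ExampleFP8MMASelector<64>::type::kNumAccum == 32, "64 * 64 / 128");
```
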
@ -0,0 +1,885 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct SM90_64x16x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %10, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
" %8,"
|
||||
" %9,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 16;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x24x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %14, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11},"
|
||||
" %12,"
|
||||
" %13,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 24;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x32x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %18, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
||||
" %16,"
|
||||
" %17,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 32;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x40x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 40;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x48x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %26, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
||||
" %24,"
|
||||
" %25,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 48;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x56x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %30, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27}, "
|
||||
" %28,"
|
||||
" %29,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 56;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x64x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %34, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31}, "
|
||||
" %32,"
|
||||
" %33,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 64;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x72x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %38, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35}, "
|
||||
" %36,"
|
||||
" %37,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 72;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x80x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %42, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39}, "
|
||||
" %40,"
|
||||
" %41,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 80;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x88x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %46, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43}, "
|
||||
" %44,"
|
||||
" %45,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 88;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x96x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %50, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47}, "
|
||||
" %48,"
|
||||
" %49,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 96;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x104x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %54, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51}, "
|
||||
" %52,"
|
||||
" %53,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 104;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x112x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %58, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55}, "
|
||||
" %56,"
|
||||
" %57,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 112;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x120x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %62, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59}, "
|
||||
" %60,"
|
||||
" %61,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 120;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x128x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %66, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63}, "
|
||||
" %64,"
|
||||
" %65,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 128;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x192x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
|
||||
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
|
||||
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
|
||||
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %98, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
||||
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
||||
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
||||
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
||||
" %88, %89, %90, %91, %92, %93, %94, %95}, "
|
||||
" %96,"
|
||||
" %97,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
||||
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
||||
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
||||
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
||||
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
|
||||
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
|
||||
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
|
||||
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 192;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
|
||||
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
|
||||
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x4_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
|
||||
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
|
||||
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
|
||||
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void warpgroup_arrive() {
|
||||
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_commit_batch() {
|
||||
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_fence_operand(float& reg) {
|
||||
asm volatile("" : "+f"(reg) :: "memory");
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t get_lane_id() {
|
||||
uint32_t lane_id;
|
||||
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
|
||||
int4 ret;
|
||||
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ void warpgroup_wait() {
|
||||
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
union GmmaDescriptor {
|
||||
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint64_t desc_;
|
||||
uint32_t reg32_[2];
|
||||
uint16_t reg16_[4];
|
||||
|
||||
struct {
|
||||
uint16_t start_address_: 14, : 2;
|
||||
uint16_t leading_byte_offset_: 14, : 2;
|
||||
uint16_t stride_byte_offset_: 14, : 2;
|
||||
uint8_t : 1, base_offset_: 3, : 4;
|
||||
uint8_t : 6, layout_type_: 2;
|
||||
} bitfield;
|
||||
|
||||
// Decay to an `uint64_t`
|
||||
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
template <class PointerType>
|
||||
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
|
||||
int leading_byte_offset = 0,
|
||||
int stride_byte_offset = 1024) {
|
||||
GmmaDescriptor desc;
|
||||
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
desc.bitfield.start_address_ = uint_ptr >> 4;
|
||||
desc.bitfield.layout_type_ = layout_type;
|
||||
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
|
||||
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
|
||||
desc.bitfield.base_offset_ = 0;
|
||||
return desc;
|
||||
}
|
||||
|
||||
template <int N>
|
||||
struct FP8MMASelector {
|
||||
static constexpr auto select_type() {
|
||||
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
|
||||
}
|
||||
|
||||
using type = decltype(select_type());
|
||||
};
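
// Illustrative usage of the helpers above (a sketch, not code from this file): `accum` is assumed to
// be a `float[MMA::kNumAccum]` register array and `desc_a`/`desc_b` shared-memory descriptors built
// with `make_smem_desc` below.
//   using MMA = typename FP8MMASelector<BLOCK_N>::type;
//   for (int i = 0; i < MMA::kNumAccum; ++ i) warpgroup_fence_operand(accum[i]);
//   warpgroup_arrive();
//   MMA::wgmma(desc_a, desc_b, accum, /*scale_d=*/true);
//   warpgroup_commit_batch();
//   warpgroup_wait<0>();
//   for (int i = 0; i < MMA::kNumAccum; ++ i) warpgroup_fence_operand(accum[i]);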
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,103 @@
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class GemmType {
|
||||
Normal,
|
||||
GroupedContiguous,
|
||||
GroupedMasked
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
||||
template <GemmType kGemmType,
|
||||
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
|
||||
uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
|
||||
uint32_t kNumNBlocksPerGroup = 16>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
uint32_t num_aligned_m_blocks;
|
||||
|
||||
// For normal GEMM
|
||||
// Maybe not used in the masked grouped GEMM
|
||||
uint32_t num_blocks;
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
// Only used for masked layout
|
||||
uint32_t curr_group_idx, curr_cumsum;
|
||||
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
curr_group_idx = curr_cumsum = 0;
|
||||
this->grouped_layout = grouped_layout;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
        // Swizzle for better L2 usage
|
||||
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
|
||||
auto group_idx = block_idx / num_blocks_per_group;
|
||||
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
|
||||
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
|
||||
auto in_group_idx = block_idx % num_blocks_per_group;
|
||||
m_block_idx = in_group_idx / num_n_blocks_in_group;
|
||||
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
|
||||
}
|
||||
|
||||
template <bool kIgnoreGroupedForGroupedContiguous=true>
|
||||
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
|
||||
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
return curr_group_idx * shape_dim + block_idx * block_size;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
||||
|
||||
if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
uint32_t num_m_blocks;
|
||||
while (true) {
|
||||
// End of the task
|
||||
if (curr_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
||||
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
|
||||
} else {
|
||||
if (next_block_idx >= num_blocks)
|
||||
return false;
|
||||
|
||||
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
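
// Illustrative persistent-kernel loop driven by the scheduler (a sketch; names are assumptions):
//   auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
//   uint32_t m_block_idx, n_block_idx;
//   while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
//       // Load the (m_block_idx, n_block_idx) tiles via TMA and run the WGMMA mainloop
//   }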
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,96 @@
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
#include <stdexcept>
|
||||
#include <cuda.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <class T>
|
||||
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
||||
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
} else if constexpr (std::is_same<T, int64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
||||
} else if constexpr (std::is_same<T, __half>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if constexpr (std::is_same<T, double>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
||||
}
|
||||
}
|
||||
|
||||
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
||||
// Get pointer to `cuTensorMapEncodeTiled`
|
||||
cudaDriverEntryPointQueryResult driver_status;
|
||||
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
||||
|
||||
#if CUDA_VERSION >= 12050
|
||||
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#else
|
||||
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#endif
|
||||
|
||||
if (driver_status != cudaDriverEntryPointSuccess)
|
||||
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
||||
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
||||
CUtensorMapSwizzle swizzle_type,
|
||||
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
||||
CUtensorMap tensor_map{};
|
||||
constexpr uint32_t rank = 2;
|
||||
uint64_t global_stride[rank - 1] = {stride_in_bytes};
|
||||
uint32_t elem_strides[rank] = {1, 1};
|
||||
|
||||
if (encode_func == nullptr)
|
||||
encode_func = get_cuTensorMapEncodeTiled();
|
||||
|
||||
auto result = encode_func(
|
||||
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
|
||||
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
||||
return tensor_map;
|
||||
}
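
// Illustrative usage (names and shapes are assumptions, not taken from this file): build a TMA
// descriptor for a row-major FP8 matrix of shape [m, k] copied in [BLOCK_M, BLOCK_K] shared-memory tiles.
//   uint64_t gmem_dim[2] = {static_cast<uint64_t>(k), static_cast<uint64_t>(m)};
//   uint32_t smem_dim[2] = {BLOCK_K, BLOCK_M};
//   CUtensorMap desc = make_2d_tma_copy_desc(gmem_ptr, gmem_dim, k * sizeof(__nv_fp8_e4m3),
//                                            smem_dim, CU_TENSOR_MAP_SWIZZLE_128B);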
|
||||
|
||||
template <uint32_t kNumTMAMulticast = 1>
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1) {
|
||||
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
} else if (cute::block_rank_in_cluster() == 0) {
|
||||
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,48 @@
|
||||
#pragma once
|
||||
|
||||
#include <exception>
#include <string>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
class AssertionException : public std::exception {
|
||||
private:
|
||||
std::string message{};
|
||||
|
||||
public:
|
||||
explicit AssertionException(const std::string& message) : message(message) {}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", \
|
||||
__FILE__, __LINE__, #cond); \
|
||||
throw AssertionException("Assertion failed: " #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_STATIC_ASSERT
|
||||
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T cell_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .runtime import Runtime
|
||||
@ -0,0 +1,146 @@
|
||||
import hashlib
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
from .template import typename_map
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
|
||||
def hash_to_hex(s: str) -> str:
|
||||
md5 = hashlib.md5()
|
||||
md5.update(s.encode('utf-8'))
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_jit_include_dir() -> str:
|
||||
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
    # Hash all CUDA headers under the GEMM include directory to version the JIT cache
|
||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||
md5.update(f.read())
|
||||
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_NVCC_COMPILER'))
|
||||
paths.append(f'{CUDA_HOME}/bin/nvcc')
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = '12.3'
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
            match = version_pattern.search(os.popen(f'{path} --version').read())
            assert match, f'Cannot get the version of NVCC compiler {path}'
            version = match.group(1)
            assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
return path, version
|
||||
raise RuntimeError('Cannot find any available NVCC compiler')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_default_user_dir():
|
||||
if 'DG_CACHE_DIR' in os.environ:
|
||||
path = os.getenv('DG_CACHE_DIR')
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
return os.path.expanduser('~') + '/.deep_gemm'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_tmp_dir():
|
||||
return f'{get_default_user_dir()}/tmp'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_cache_dir():
|
||||
return f'{get_default_user_dir()}/cache'
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
tmp_dir = get_tmp_dir()
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data, is_binary=False):
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
||||
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=177,174,940']
|
||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
args_path = f'{path}/kernel.args'
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary SO file
|
||||
so_path = f'{path}/kernel.so'
|
||||
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
|
||||
|
||||
# Compile
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', tmp_so_path,
|
||||
*flags,
|
||||
*[f'-I{d}' for d in include_dirs]]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
assert subprocess.check_call(command) == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
pass
|
||||
|
||||
# Atomic replace SO file
|
||||
os.replace(tmp_so_path, so_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
@ -0,0 +1,66 @@
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from .template import map_ctype
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.args = None
|
||||
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
def is_path_valid(path: str) -> bool:
|
||||
# Exists and is a directory
|
||||
if not os.path.exists(path) or not os.path.isdir(path):
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cu', 'kernel.args', 'kernel.so']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, *args) -> int:
|
||||
# Load SO file
|
||||
if self.lib is None or self.args is None:
|
||||
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
|
||||
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
|
||||
self.args = eval(f.read())
|
||||
|
||||
# Check args and launch
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
cargs = []
|
||||
for arg, (name, dtype) in zip(args, self.args):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
|
||||
else:
|
||||
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
|
||||
cargs.append(map_ctype(arg))
|
||||
|
||||
return_code = ctypes.c_int(0)
|
||||
self.lib.launch(*cargs, ctypes.byref(return_code))
|
||||
return return_code.value
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
|
||||
def __getitem__(self, path: str) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
runtime = Runtime(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
@ -0,0 +1,93 @@
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
|
||||
from typing import Any, Iterable, Dict, Tuple
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
torch.int: 'torch.int',
|
||||
torch.float: 'torch.float',
|
||||
torch.bfloat16: 'torch.bfloat16',
|
||||
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
|
||||
torch.cuda.Stream: 'torch.cuda.Stream',
|
||||
}
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
# Type map for both Python API and source code usages
|
||||
genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
float: ('float', 'float'),
|
||||
torch.int: ('void*', 'int*'),
|
||||
torch.float: ('void*', 'float*'),
|
||||
torch.bfloat16: ('void*', '__nv_bfloat16*'),
|
||||
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
torch.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
|
||||
if isinstance(value, torch.Tensor):
|
||||
return ctype(value.data_ptr())
|
||||
if isinstance(value, torch.cuda.Stream):
|
||||
return ctype(value.cuda_stream)
|
||||
return ctype(value)
|
||||
|
||||
|
||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||
# We don't use `str.format` because it's not safe for C++ {} braces
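    # e.g. cpp_format('if (x) { y = {N}; }', {'N': 128}) == 'if (x) { y = 128; }'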
|
||||
new_template = copy.deepcopy(template)
|
||||
for key, value in keys.items():
|
||||
new_template = new_template.replace(f'{{{key}}}', f'{value}')
|
||||
return new_template
|
||||
|
||||
|
||||
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
|
||||
# Common prefix
|
||||
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
|
||||
|
||||
# Includes
|
||||
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
|
||||
preload_package_includes = ['"cutlass/cutlass.h"']
|
||||
|
||||
assert isinstance(includes, list) or isinstance(includes, tuple)
|
||||
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
|
||||
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
|
||||
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
|
||||
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
|
||||
|
||||
# Function signature
|
||||
raw = '__raw_'
|
||||
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
|
||||
code += f'extern "C" void launch('
|
||||
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
|
||||
code += ') {\n'
|
||||
|
||||
# Cast raw types
|
||||
code += ' // Cast raw types (if needed)\n'
|
||||
for arg_name, arg_type in arg_defs:
|
||||
if genc_map[arg_type][0] != genc_map[arg_type][1]:
|
||||
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
|
||||
|
||||
# Function body
|
||||
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
|
||||
|
||||
# End the function
|
||||
code += '}\n\n'
|
||||
|
||||
# Debug print
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Generated code:\n{code}')
|
||||
|
||||
return code
|
||||
@ -0,0 +1,10 @@
|
||||
from .gemm import gemm_fp8_fp8_bf16_nt
from .m_grouped_gemm import (
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
    m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .utils import (
    cell_div, set_num_sms, get_num_sms,
    get_col_major_tma_aligned_tensor,
    get_m_alignment_for_contiguous_layout
)
|
||||
@ -0,0 +1,171 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, cell_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, nullptr,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
|
||||
if num_tma_multicast == 1:
|
||||
return True
|
||||
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
|
||||
|
||||
|
||||
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
|
||||
smem_d = block_m * block_n * 2
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b = cell_div(k, block_k) * 4
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
smem_size += smem_d
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2)
|
||||
smem_size += smem_barrier
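    # Worked example (illustrative): with num_stages=5, k=7168, block_m=128, block_n=128:
    #   smem_d = 32768, per-stage A/A-scales/B = 16384 + 512 + 16384, scales_b = 224, barriers = 80,
    #   for a total of 199472 bytes, which fits under the 232448-byte SM90 capacity used by `get_best_configs`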
|
||||
return smem_size
|
||||
|
||||
|
||||
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
|
||||
if not is_grouped_contiguous:
|
||||
# TODO: for some cases, smaller M block is better, add them into tuning space
|
||||
block_ms = (64 if m <= 64 else 128, )
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (cell_div(cell_div(m, bm) * cell_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((cell_div(m, bm) * cell_div(n, bn) * num_groups) % num_sms)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
for block_m in block_ms:
|
||||
for block_n in block_ns:
|
||||
success = False
|
||||
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
success = True
|
||||
elif num_waves == best_num_waves:
|
||||
# Check last wave utilization
|
||||
util = get_last_wave_util(block_m, block_n)
|
||||
best_util = get_last_wave_util(best_block_m, best_block_n)
|
||||
success = util > best_util or (util == best_util and (block_n >= best_block_n and block_m <= best_block_m))
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
||||
assert best_block_m is not None and best_block_n is not None
|
||||
|
||||
    # Always pick the deepest pipeline (largest number of stages) that fits in shared memory
|
||||
# NOTES: for double B scales, the best number of stages may be reduced
|
||||
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
||||
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
||||
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
||||
if best_smem_size <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
assert best_num_stages is not None
|
||||
|
||||
# Decide the number of TMA multicast
|
||||
best_num_tma_multicast = 1
|
||||
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
|
||||
best_num_tma_multicast = 2
|
||||
|
||||
return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
|
||||
|
||||
|
||||
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
    The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match the requirement,
    this function will transpose it with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m, n]`, representing the result.
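
    Example (an illustrative sketch; real scaling factors would come from quantization, not `torch.ones`):
        >>> m, n, k = 256, 4096, 7168
        >>> lhs = (torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn),
        ...        torch.ones(m, (k + 127) // 128, device='cuda', dtype=torch.float32))
        >>> rhs = (torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn),
        ...        torch.ones((n + 127) // 128, (k + 127) // 128, device='cuda', dtype=torch.float32))
        >>> out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
        >>> gemm_fp8_fp8_bf16_nt(lhs, rhs, out)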
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and k > 0
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
@ -0,0 +1,182 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .gemm import get_best_configs
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
# C++ code templates
|
||||
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
||||
template = """
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Templated args from Python JIT call
|
||||
constexpr auto N = {N}, K = {K};
|
||||
constexpr auto BLOCK_M = {BLOCK_M};
|
||||
constexpr auto BLOCK_N = {BLOCK_N};
|
||||
constexpr auto kNumStages = {NUM_STAGES};
|
||||
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
||||
|
||||
// Make a templated grouped GEMM
|
||||
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
|
||||
|
||||
// Launch kernel
|
||||
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
||||
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
||||
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
||||
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
||||
GemmType::run(out, rhs_scales, grouped_layout,
|
||||
m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
||||
stream, num_sms, smem_size);
|
||||
"""
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
    The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match the requirement,
    this function will transpose it with a set of slow PyTorch operations.
    On the M axis, inputs are grouped into several batches whose sizes must be aligned to
    `get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
||||
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
|
||||
            `m_indices[i]` records the group to which the i-th row of the LHS belongs,
|
||||
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
            `-1` in this tensor indicates that no RHS matrix is selected; the kernel will skip the computation for that aligned block.
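            Example (illustrative): for two groups of 256 rows each,
            `m_indices = torch.arange(2, device='cuda', dtype=torch.int32).repeat_interleave(256)`
            selects `rhs[0]` for the first 256 rows of the LHS and `rhs[1]` for the next 256 rows.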
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
num_groups, n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
m__ = m_indices.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ == m__ and k == k_ and n == n_
|
||||
assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert m_indices.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
|
||||
is_grouped_contiguous=True)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
m_indices, m, num_groups,
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
    The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match the requirement,
    this function will transpose it with a set of slow PyTorch operations.
    Moreover, this alignment requirement differs from the contiguous-format kernel, as we require each batch to be
    transposed separately.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
in the i-th group.
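            Example (illustrative): with `num_groups = 2` and `masked_m = torch.tensor([300, 7], device='cuda', dtype=torch.int32)`,
            only the first 300 rows of `lhs[0]` and the first 7 rows of `lhs[1]` are computed.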
|
||||
        expected_m: a CPU-side hint for the expected M of each batch;
            setting this value correctly may lead to better performance.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
num_groups, m, k = lhs.shape
|
||||
num_groups_, n, k_ = rhs.shape
|
||||
num_groups__, m_, n_ = out.shape
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
||||
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert masked_m.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
torch.cuda.current_stream(), num_sms, smem_size)
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
('out', torch.bfloat16),
|
||||
('grouped_layout', torch.int32), ('m', int),
|
||||
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
runtime(*args)
|
||||
@ -0,0 +1,81 @@
|
||||
import copy
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
|
||||
|
||||
class JITTuner:
|
||||
def __init__(self) -> None:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
||||
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
signature = (name, f'{keys}')
|
||||
if signature in self.tuned:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT kernel {name} with keys {keys}')
|
||||
return self.tuned[signature]
|
||||
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
||||
|
||||
assert signature not in self.tuned
|
||||
assert args is not None
|
||||
space = (dict(), ) if len(space) == 0 else space
|
||||
|
||||
kernels = []
|
||||
for tuned_keys in space:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(includes, arg_defs, cpp_format(template, full_keys))
|
||||
|
||||
# Illegal build must raise errors
|
||||
kernels.append((build(name, arg_defs, code), tuned_keys))
|
||||
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
if len(space) > 1:
|
||||
# Check kernel validity
|
||||
return_code = runtime(*args)
|
||||
if return_code != 0:
|
||||
                    # Skip illegal kernels, e.g. those failing due to insufficient shared memory capacity
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
continue
|
||||
|
||||
# Measure performance, flushing L2 and running a large GEMM kernel beforehand to reduce launch overhead between kernels
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(*args) == 0
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
elapsed_time = start_event.elapsed_time(end_event)
|
||||
else:
|
||||
elapsed_time = 0
|
||||
|
||||
# Compare if better
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
|
||||
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
||||
self.tuned[signature] = best_runtime
|
||||
return best_runtime
|
||||
|
||||
|
||||
jit_tuner = JITTuner()
|
||||
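# --- Illustrative sketch (not part of the diff): the timing pattern `compile_and_tune` uses above.
# Assumptions: a CUDA device is available and `fn` is any zero-argument callable launching a kernel.
# The pattern flushes L2 with a 256 MB buffer, hides launch overhead behind a large GEMM,
# then times 20 back-to-back runs with CUDA events.
import torch

def time_candidate(fn, num_runs: int = 20) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()   # flush L2
    torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
    start.record()
    for _ in range(num_runs):
        fn()
    end.record()
    end.synchronize()
    return start.elapsed_time(end)   # total milliseconds for `num_runs` runs

# The best candidate is then simply the one with the smallest measured time, e.g.:
# best_runtime = min(candidate_runtimes, key=lambda rt: time_candidate(lambda: rt(*args)))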
@@ -0,0 +1,105 @@
|
||||
import torch
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
def set_num_sms(num_sms: int) -> None:
|
||||
"""
|
||||
Set the maximum SM count for all GEMM kernels to use.
|
||||
|
||||
Arguments:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
"""
|
||||
Get the current maximum limit of SM count for all GEMM kernels to use.
|
||||
If the count is never specified, the function will return the number of device SMs.
|
||||
|
||||
Returns:
|
||||
Current maximum limit of SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
if _num_sms is None:
|
||||
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
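# Illustrative usage (not part of the diff): cap all GEMM kernels at half of the device's SM count.
# Assumes a CUDA device is present.
full_count = torch.cuda.get_device_properties(device='cuda').multi_processor_count
set_num_sms(full_count // 2)
assert get_num_sms() == full_count // 2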
def cell_div(x: int, y: int) -> int:
|
||||
"""
|
||||
Perform ceiling division of two integers.
|
||||
|
||||
Args:
|
||||
x: the dividend.
|
||||
y: the divisor.
|
||||
|
||||
Returns:
|
||||
The result of the ceiling division.
|
||||
"""
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
|
||||
"""
|
||||
When we do a grouped GEMM in contiguous format, the LHS matrices are concatenated into batches along the M axis.
Since each GEMM block processes exactly one sub-matrix of the RHS, the per-group batch sizes must be aligned to the
GEMM block shape.
|
||||
|
||||
Returns:
|
||||
Group-level alignment requirement for grouped contiguous layout, which is always 128.
|
||||
"""
|
||||
return 128
|
||||
|
||||
|
||||
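# Illustrative note (not part of the diff): per-group M sizes for the contiguous grouped GEMM
# are expected to be rounded up to this alignment; the value 777 below is a placeholder.
alignment = get_m_alignment_for_contiguous_layout()
padded_m = cell_div(777, alignment) * alignment   # 777 -> 896, a multiple of 128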
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
"""
|
||||
Global memory addresses used by TMA must be 16-byte aligned.
Since we use a column-major layout for the LHS scaling tensor,
its M axis must be padded up to a multiple of 16 bytes.
|
||||
|
||||
Arguments:
|
||||
x: original M-axis shape of the LHS scaling tensor.
|
||||
element_size: element size of the LHS scaling tensor.
|
||||
|
||||
Returns:
|
||||
M-axis shape of the LHS scaling tensor after padding.
|
||||
"""
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return cell_div(x, alignment) * alignment
|
||||
|
||||
|
||||
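# Worked example (not part of the diff): FP32 scales have element_size == 4, so the alignment
# is 16 // 4 = 4 elements and an M extent of 9 is padded up to 12.
assert get_tma_aligned_size(9, 4) == 12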
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns the input tensor in a TMA-aligned, transposed (column-major) format; `torch.transpose` is called if necessary.
If the input tensor is already column-major and 16-byte aligned along the M axis
(and thus already meets the requirement for the LHS scaling tensor in DeepGEMM), this function does nothing.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
# NOTES: for maximum performance, you may rewrite/fuse this function in CUDA
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
if x.dim() == 2:
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b, m, n = x.shape
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
aligned_x[:, :m, :] = x
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
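# Illustrative usage (not part of the diff): align a per-128-channel FP8 scaling tensor of
# shape (m, k // 128); m = 1022 and k = 7168 are placeholders. The result is column-major
# along M (stride 1), with the M extent padded up to a multiple of 4 float32 elements.
lhs_scales = torch.randn(1022, 7168 // 128, dtype=torch.float32, device='cuda')
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert lhs_scales.stride(0) == 1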
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
torch.cuda.synchronize()
|
||||
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
cache.zero_()
|
||||
|
||||
# Warmup
|
||||
for _ in range(num_warmups):
|
||||
fn()
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
x @ y
|
||||
|
||||
# Testing
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for i in range(num_tests):
|
||||
fn()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
|
||||
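# Illustrative usage (not part of the diff): time an 8192x8192 float matmul with the helper above.
# Assumes a CUDA device; the sizes are placeholders.
x = torch.randn(8192, 8192, dtype=torch.float, device='cuda')
y = torch.randn(8192, 8192, dtype=torch.float, device='cuda')
avg_ms = bench(lambda: x @ y, num_warmups=5, num_tests=10)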
class empty_suppress:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
pass
|
||||
|
||||
|
||||
class suppress_stdout_stderr:
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, 'w')
|
||||
self.errnull_file = open(os.devnull, 'w')
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False):
|
||||
# Kineto profiling conflicts with Nsight Systems
|
||||
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
|
||||
|
||||
# Run the function once first, since some auto-tuning kernels print during their first invocation
|
||||
fn()
|
||||
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
with suppress():
|
||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
|
||||
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
||||
if barrier_comm_profiling:
|
||||
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
lhs @ rhs
|
||||
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
|
||||
for _ in range(num_tests):
|
||||
if flush_l2:
|
||||
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
profiler.step()
|
||||
|
||||
# Return 1 if using Nsight Systems
|
||||
if using_nsys:
|
||||
return 1
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tupled = isinstance(kernel_names, tuple)
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
profiler.export_chrome_trace(trace_path)
|
||||
|
||||
# Return average kernel times
|
||||
units = {'ms': 1e3, 'us': 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
||||
break
|
||||
break
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
def calc_diff(x, y):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
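# Illustrative note (not part of the diff): calc_diff returns 1 minus a cosine-style similarity
# 2 * sum(x * y) / (sum(x * x) + sum(y * y)), so identical tensors give 0 and the value grows with error.
ref = torch.ones(4, dtype=torch.float32)
assert calc_diff(ref, ref) == 0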
def count_bytes(tensors):
|
||||
total = 0
|
||||
for t in tensors:
|
||||
if isinstance(t, tuple):
|
||||
total += count_bytes(t)
|
||||
else:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
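# Illustrative note (not part of the diff): count_bytes recurses into nested tuples, so mixed
# argument packs can be passed directly; the shape below is a placeholder.
t = torch.empty(2, 3, dtype=torch.bfloat16)
assert count_bytes((t, (t, t))) == 3 * t.numel() * t.element_size()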
@@ -34,7 +34,6 @@ template<typename RasterOrderOptions>
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool verify = true;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
|
||||
@@ -42,12 +41,9 @@ struct Options {
|
||||
bool save_aux = true;
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int warmup = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
float epsilon = 0.02f;
|
||||
float non_zero_floor = 1.f;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@@ -72,11 +68,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("device_scale", device_scale, false);
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("warmup", warmup);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("verify", verify);
|
||||
cmd.get_cmd_line_argument("epsilon", epsilon);
|
||||
cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
@@ -97,8 +89,8 @@ struct Options {
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n"
|
||||
<< " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
|
||||
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
@@ -117,14 +109,11 @@ struct Options {
|
||||
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --verify=<bool> Verify the results.\n\n"
|
||||
<< " --epsilon=<float> The epsilon value for comparing the results.\n\n"
|
||||
<< " --non-zero-floor=<float> The none zero floor for comparing the results.\n\n";
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,444 @@
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
};
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(cell_div(BLOCK_N, BLOCK_K) == 1, "Too many B scales in a single block");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SHAPE_K_SCALES = cell_div(SHAPE_K, BLOCK_K);
|
||||
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
constexpr uint32_t kNumIterations = cell_div(SHAPE_K, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_scales_a[kNumStages];
|
||||
float* smem_scales_b;
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
Barrier* empty_barriers[kNumStages];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
|
||||
|
||||
// Fill barriers
|
||||
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
|
||||
DG_STATIC_ASSERT(not kMustUseUniformedScaleB or SHAPE_K_SCALES % (sizeof(Barrier) / sizeof(float)) == 0, "Misaligned barriers");
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_scales_b + SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicasts");
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [](const auto& func) {
|
||||
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
||||
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(kNumIterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr int kNumTMARegisters = 40;
|
||||
constexpr int kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
|
||||
// Issue TMA A with broadcasting
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
|
||||
|
||||
// Issue TMA B without broadcasting
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
||||
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
|
||||
if constexpr (not kMustUseUniformedScaleB) {
|
||||
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
|
||||
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
|
||||
}
|
||||
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
|
||||
|
||||
// Load B scales with math warp-groups
|
||||
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
||||
if (threadIdx.x >= 32) {
|
||||
auto num_previous_lines = scheduler.get_global_idx<false>(cell_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
|
||||
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
|
||||
}
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
|
||||
}
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
|
||||
// NOTES: even though some blocks do not need to read the second row, we still load one to stay aligned with the other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
|
||||
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// Write back to shared memory using STSM
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
|
||||
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
|
||||
);
|
||||
}
|
||||
if constexpr (WGMMA::kNumAccum % 8 != 0) {
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
|
||||
);
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAMulticast,
|
||||
GemmType kGemmType>
|
||||
class Gemm {
|
||||
private:
|
||||
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
|
||||
public:
|
||||
Gemm() = default;
|
||||
|
||||
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const CUtensorMap& tma_a_desc,
|
||||
const CUtensorMap& tma_b_desc,
|
||||
const CUtensorMap& tma_scales_a_desc,
|
||||
const CUtensorMap& tma_d_desc,
|
||||
cudaStream_t stream,
|
||||
int num_sms, uint32_t smem_size) {
|
||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
||||
constexpr uint32_t kNumTMAThreads = 128;
|
||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
||||
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
|
||||
kNumTMAMulticast, kGemmType>;
|
||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
||||
|
||||
// Cluster launch
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_sms;
|
||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
|
||||
// Clusters for TMA multicast
|
||||
// NOTES: a cluster size `>= 4` will cause performance degradation
|
||||
cudaLaunchAttribute attr;
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
|
||||
// Launch
|
||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
||||
gmem_d, scales_b, grouped_layout,
|
||||
shape_m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = cell_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
shape_m, cell_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_desc(
|
||||
T* global_address, Layout layout,
|
||||
uint32_t gmem_rows, uint32_t gmem_cols,
|
||||
uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
@@ -0,0 +1,885 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct SM90_64x16x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %10, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
" %8,"
|
||||
" %9,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 16;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x24x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %14, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11},"
|
||||
" %12,"
|
||||
" %13,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 24;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x32x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %18, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
||||
" %16,"
|
||||
" %17,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 32;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x40x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %22, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19},"
|
||||
" %20,"
|
||||
" %21,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 40;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x48x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %26, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
||||
" %24,"
|
||||
" %25,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 48;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x56x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %30, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27}, "
|
||||
" %28,"
|
||||
" %29,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 56;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x64x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %34, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31}, "
|
||||
" %32,"
|
||||
" %33,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 64;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x72x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %38, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35}, "
|
||||
" %36,"
|
||||
" %37,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 72;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x80x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %42, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39}, "
|
||||
" %40,"
|
||||
" %41,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 80;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x88x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %46, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43}, "
|
||||
" %44,"
|
||||
" %45,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 88;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x96x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %50, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47}, "
|
||||
" %48,"
|
||||
" %49,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 96;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x104x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %54, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51}, "
|
||||
" %52,"
|
||||
" %53,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 104;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x112x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %58, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55}, "
|
||||
" %56,"
|
||||
" %57,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 112;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x120x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %62, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59}, "
|
||||
" %60,"
|
||||
" %61,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 120;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x128x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %66, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63}, "
|
||||
" %64,"
|
||||
" %65,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 128;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
struct SM90_64x192x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
|
||||
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
|
||||
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
|
||||
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %98, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
||||
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
||||
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
||||
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
||||
" %88, %89, %90, %91, %92, %93, %94, %95}, "
|
||||
" %96,"
|
||||
" %97,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
||||
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
||||
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
||||
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
||||
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
|
||||
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
|
||||
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
|
||||
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 192;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
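// Note (illustrative, not part of the original file): a warpgroup is 128
// threads, so kNumAccum = M * N / 128 is the number of FP32 accumulator
// registers each thread holds for one WGMMA tile, e.g. 64 * 192 / 128 = 96
// for the struct above.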
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
|
||||
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
|
||||
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x4_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
|
||||
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
|
||||
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
|
||||
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
|
||||
}
|
||||
};
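// Illustrative usage sketch (not part of the original file): storing four
// packed 16-bit pairs from registers to shared memory via stmatrix. The
// fragment names below are assumptions used only for illustration.
//   __nv_bfloat162 r0, r1, r2, r3;  // four 2x16-bit register fragments
//   SM90_U32x4_STSM_N<__nv_bfloat162>::copy(r0, r1, r2, r3, smem_dst);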
|
||||
|
||||
__device__ void warpgroup_arrive() {
|
||||
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_commit_batch() {
|
||||
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__device__ void warpgroup_fence_operand(float& reg) {
|
||||
asm volatile("" : "+f"(reg) :: "memory");
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t get_lane_id() {
|
||||
uint32_t lane_id;
|
||||
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
|
||||
int4 ret;
|
||||
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ void warpgroup_wait() {
|
||||
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
||||
}
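// Illustrative ordering sketch (not part of the original file): the helpers
// above are normally used as fence operands -> arrive -> issue WGMMAs ->
// commit -> wait. `accum` is an assumed accumulator array name.
//   for (auto& v : accum) warpgroup_fence_operand(v);
//   warpgroup_arrive();
//   /* issue one or more wgmma(...) calls here */
//   warpgroup_commit_batch();
//   warpgroup_wait<0>();  // block until all committed WGMMA groups complete
//   for (auto& v : accum) warpgroup_fence_operand(v);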
|
||||
|
||||
union GmmaDescriptor {
|
||||
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint64_t desc_;
|
||||
uint32_t reg32_[2];
|
||||
uint16_t reg16_[4];
|
||||
|
||||
struct {
|
||||
uint16_t start_address_: 14, : 2;
|
||||
uint16_t leading_byte_offset_: 14, : 2;
|
||||
uint16_t stride_byte_offset_: 14, : 2;
|
||||
uint8_t : 1, base_offset_: 3, : 4;
|
||||
uint8_t : 6, layout_type_: 2;
|
||||
} bitfield;
|
||||
|
||||
// Decay to a `uint64_t`
|
||||
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
template <class PointerType>
|
||||
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
|
||||
int leading_byte_offset = 0,
|
||||
int stride_byte_offset = 1024) {
|
||||
GmmaDescriptor desc;
|
||||
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
desc.bitfield.start_address_ = uint_ptr >> 4;
|
||||
desc.bitfield.layout_type_ = layout_type;
|
||||
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
|
||||
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
|
||||
desc.bitfield.base_offset_ = 0;
|
||||
return desc;
|
||||
}
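// Illustrative usage sketch (not part of the original file): building a GMMA
// descriptor for a 128B-swizzled shared-memory tile. `smem_tile` is an assumed
// buffer name; layout_type 1 corresponds to the 128-byte swizzle mode in this
// descriptor encoding.
//   __shared__ __align__(1024) __nv_fp8_e4m3 smem_tile[64 * 32];
//   GmmaDescriptor desc_a = make_smem_desc(smem_tile, /*layout_type=*/1);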
|
||||
|
||||
template <int N>
|
||||
struct FP8MMASelector {
|
||||
static constexpr auto select_type() {
|
||||
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
|
||||
}
|
||||
|
||||
using type = decltype(select_type());
|
||||
};
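// Illustrative usage sketch (not part of the original file): selecting the
// WGMMA tile for a kernel's BLOCK_N and issuing one K-step. `desc_a`,
// `desc_b`, and `accum` are assumed names for illustration only.
//   using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
//   float accum[WGMMA::kNumAccum];
//   warpgroup_arrive();
//   WGMMA::wgmma(desc_a, desc_b, accum, /*scale_d=*/false);  // first step overwrites accum
//   warpgroup_commit_batch();
//   warpgroup_wait<0>();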
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,103 @@
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class GemmType {
|
||||
Normal,
|
||||
GroupedContiguous,
|
||||
GroupedMasked
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
||||
template <GemmType kGemmType,
|
||||
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
|
||||
uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
|
||||
uint32_t kNumNBlocksPerGroup = 16>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
uint32_t num_aligned_m_blocks;
|
||||
|
||||
// For normal GEMM
|
||||
// May be unused in the masked grouped GEMM
|
||||
uint32_t num_blocks;
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
// Only used for masked layout
|
||||
uint32_t curr_group_idx, curr_cumsum;
|
||||
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
curr_group_idx = curr_cumsum = 0;
|
||||
this->grouped_layout = grouped_layout;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usage
|
||||
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
|
||||
auto group_idx = block_idx / num_blocks_per_group;
|
||||
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
|
||||
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
|
||||
auto in_group_idx = block_idx % num_blocks_per_group;
|
||||
m_block_idx = in_group_idx / num_n_blocks_in_group;
|
||||
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
|
||||
}
|
||||
|
||||
template <bool kIgnoreGroupedForGroupedContiguous=true>
|
||||
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
|
||||
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
return curr_group_idx * shape_dim + block_idx * block_size;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
||||
|
||||
if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
uint32_t num_m_blocks;
|
||||
while (true) {
|
||||
// End of the task
|
||||
if (curr_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
||||
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
|
||||
} else {
|
||||
if (next_block_idx >= num_blocks)
|
||||
return false;
|
||||
|
||||
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
#pragma clang diagnostic pop
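// Illustrative usage sketch (not part of the original file): a persistent
// kernel drains tiles from the scheduler until it reports completion. The
// template arguments and `shape_m` are assumed example values.
//   auto scheduler = Scheduler<GemmType::Normal, /*SHAPE_N=*/4096,
//                              /*BLOCK_M=*/128, /*BLOCK_N=*/128,
//                              /*kNumGroups=*/1, /*kNumTMAMulticast=*/1>(shape_m);
//   uint32_t m_block_idx, n_block_idx;
//   while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
//       /* load the (m_block_idx, n_block_idx) tiles and run the mainloop */
//   }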
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,96 @@
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <class T>
|
||||
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
||||
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
} else if constexpr (std::is_same<T, int64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
||||
} else if constexpr (std::is_same<T, __half>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if constexpr (std::is_same<T, double>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
||||
}
|
||||
}
|
||||
|
||||
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
||||
// Get pointer to `cuTensorMapEncodeTiled`
|
||||
cudaDriverEntryPointQueryResult driver_status;
|
||||
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
||||
|
||||
#if CUDA_VERSION >= 12050
|
||||
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#else
|
||||
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#endif
|
||||
|
||||
if (driver_status != cudaDriverEntryPointSuccess)
|
||||
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
||||
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
||||
CUtensorMapSwizzle swizzle_type,
|
||||
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
||||
CUtensorMap tensor_map{};
|
||||
constexpr uint32_t rank = 2;
|
||||
uint64_t global_stride[rank - 1] = {stride_in_bytes};
|
||||
uint32_t elem_strides[rank] = {1, 1};
|
||||
|
||||
if (encode_func == nullptr)
|
||||
encode_func = get_cuTensorMapEncodeTiled();
|
||||
|
||||
auto result = encode_func(
|
||||
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
|
||||
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
||||
return tensor_map;
|
||||
}
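// Illustrative usage sketch (not part of the original file): describing a
// row-major FP8 matrix of shape [shape_m, shape_k] loaded in [BLOCK_M, BLOCK_K]
// tiles. The variable names are assumptions only; note that the innermost
// (contiguous) dimension comes first in gmem_dim/smem_dim.
//   uint64_t gmem_dim[2] = {shape_k, shape_m};
//   uint32_t smem_dim[2] = {BLOCK_K, BLOCK_M};
//   CUtensorMap desc = make_2d_tma_copy_desc(gmem_ptr, gmem_dim,
//                                            shape_k * sizeof(__nv_fp8_e4m3),
//                                            smem_dim, CU_TENSOR_MAP_SWIZZLE_128B);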
|
||||
|
||||
template <uint32_t kNumTMAMulticast = 1>
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1) {
|
||||
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
} else if (cute::block_rank_in_cluster() == 0) {
|
||||
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@ -0,0 +1,48 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdio>
#include <exception>
#include <string>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
class AssertionException : public std::exception {
|
||||
private:
|
||||
std::string message{};
|
||||
|
||||
public:
|
||||
explicit AssertionException(const std::string& message) : message(message) {}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", \
|
||||
__FILE__, __LINE__, #cond); \
|
||||
throw AssertionException("Assertion failed: " #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_STATIC_ASSERT
|
||||
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T cell_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
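// Illustrative note (not part of the original file): cell_div rounds up, e.g.
// cell_div(100, 64) == 2, so shape_m = 100 with BLOCK_M = 64 yields two
// M-blocks in the scheduler.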
|
||||
@ -0,0 +1,504 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GETT in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/relatively_equal.h"
|
||||
#include <iostream>
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::reference::host {
|
||||
|
||||
template<class T, class = void>
|
||||
struct ElementTraits {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
|
||||
using type = decltype(std::declval<T>().get());
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAccumulator_,
|
||||
class TensorA_, // (M, K, L)
|
||||
class TensorB_, // (N, K, L)
|
||||
class TensorScaleA_, // (m, k, L)
|
||||
class TensorScaleB_, // (n, k, L)
|
||||
class TileShape_
|
||||
>
|
||||
struct GettMainloopParams {
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using TensorA = TensorA_;
|
||||
using TensorB = TensorB_;
|
||||
using EngineA = typename TensorA::engine_type;
|
||||
using LayoutA = typename TensorA::layout_type;
|
||||
using EngineB = typename TensorB::engine_type;
|
||||
using LayoutB = typename TensorB::layout_type;
|
||||
|
||||
using TensorScaleA = TensorScaleA_;
|
||||
using TensorScaleB = TensorScaleB_;
|
||||
using TileShape = TileShape_;
|
||||
using EngineScaleA = typename TensorScaleA::engine_type;
|
||||
using EngineScaleB = typename TensorScaleB::engine_type;
|
||||
|
||||
TensorA A{};
|
||||
TensorB B{};
|
||||
TensorScaleA ScaleA{};
|
||||
TensorScaleB ScaleB{};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementScalingFactor_,
|
||||
class ElementAccumulator_,
|
||||
class ElementCompute_,
|
||||
class TensorC_, // (M, N, L)
|
||||
class TensorD_, // (M, N, L)
|
||||
class VectorBias_ = TensorD_, // (M, 1)
|
||||
class TensorAux_ = TensorD_, // (M, N, L)
|
||||
class VectorAlpha_ = TensorD_, // (M, 1)
|
||||
class VectorBeta_ = VectorAlpha_, // (M, 1)
|
||||
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
|
||||
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
|
||||
bool PerColumnBias_ = false
|
||||
>
|
||||
struct GettEpilogueParams {
|
||||
using ElementScalar = ElementScalar_;
|
||||
using ElementScalingFactor = ElementScalingFactor_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using TensorC = TensorC_;
|
||||
using TensorD = TensorD_;
|
||||
using TensorAux = TensorAux_;
|
||||
using VectorBias = VectorBias_;
|
||||
using VectorAlpha = VectorAlpha_;
|
||||
using VectorBeta = VectorBeta_;
|
||||
using ActivationFunctor = ActivationFunctor_;
|
||||
using BiasBinaryOp = BiasBinaryOp_;
|
||||
|
||||
using EngineC = typename TensorC::engine_type;
|
||||
using LayoutC = typename TensorC::layout_type;
|
||||
using EngineD = typename TensorD::engine_type;
|
||||
using LayoutD = typename TensorD::layout_type;
|
||||
static constexpr bool PerColumnBias = PerColumnBias_;
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
|
||||
TensorC C{};
|
||||
TensorD D{};
|
||||
VectorBias Bias{};
|
||||
TensorAux Aux{};
|
||||
VectorAlpha Valpha{};
|
||||
VectorBeta Vbeta{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
|
||||
ElementAccumulator* abs_max_D = nullptr;
|
||||
ElementAccumulator* abs_max_Aux = nullptr;
|
||||
|
||||
ElementScalingFactor scale_a = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_b = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_c = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_d = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_aux = ElementScalingFactor(1);
|
||||
|
||||
bool beta_per_channel_scaling = false;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - General Tensor-Tensor contraction reference kernel with Blockwise scaling
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gett(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
|
||||
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
|
||||
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
|
||||
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
|
||||
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel for collapse(3)
|
||||
#endif
|
||||
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
|
||||
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
|
||||
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
|
||||
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
|
||||
gett_mainloop(mainloop_params, m, n, l, acc);
|
||||
gett_epilogue(epilogue_params, m, n, l, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Mainloop
|
||||
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_mainloop(
|
||||
MainloopParams const& mainloop_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
||||
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
||||
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
|
||||
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
|
||||
|
||||
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
|
||||
RingOp fma_op;
|
||||
|
||||
multiplies<ElementAccumulator> scale_op;
|
||||
|
||||
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});
|
||||
|
||||
// Temporary accumulators to separate blockwise accumulation
|
||||
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
|
||||
|
||||
// Zero out accumulators
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t block_m = m / kBlockM;
|
||||
int64_t block_n = n / kBlockN;
|
||||
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, l);
|
||||
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l);
|
||||
|
||||
// Compute on this k-block
|
||||
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
|
||||
|
||||
// Load the blockwise scaling factors from the blockscale tensors for A and B
|
||||
int64_t block_k = k / kBlockK;
|
||||
ElementBlockScaleA scale_a = blockscale_A[block_k];
|
||||
ElementBlockScaleB scale_b = blockscale_B[block_k];
|
||||
|
||||
// Load A
|
||||
ElementAccumulator a_frag[kBlockM];
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
||||
} else {
|
||||
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// Load B
|
||||
ElementAccumulator b_frag[kBlockN];
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
|
||||
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
||||
} else {
|
||||
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// do compute
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply blockwise scaling at the kBlockK boundary:
// (a) apply the block scaling factors to the partial accumulated results (acc_temp),
// (b) zero out the partial temporary (acc_temp),
// (c) update the permanent accumulator (acc)
|
||||
if ((k+1) % kBlockK == 0) {
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a * scale_b;
|
||||
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
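// Summary sketch (added for clarity, not part of the original file): the loop
// above computes, for each output element,
//   acc[m][n] = sum over K-blocks kb of scale_a[kb] * scale_b[kb]
//                 * sum_{k in kb} A[m][k] * B[n][k]
// i.e. products are accumulated unscaled inside each kBlockK slice and the
// per-block scale product is applied once at the slice boundary.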
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Epilogue
|
||||
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_epilogue(
|
||||
EpilogueParams const& epilogue_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, N, B");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "M, N, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementCompute = typename EpilogueParams::ElementCompute;
|
||||
using ElementC = typename EpilogueParams::TensorC::value_type;
|
||||
using ElementD = typename EpilogueParams::TensorD::value_type;
|
||||
using ElementAux = typename EpilogueParams::TensorAux::value_type;
|
||||
using ElementBias = typename EpilogueParams::VectorBias::value_type;
|
||||
using ElementScalar = typename EpilogueParams::ElementScalar;
|
||||
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
|
||||
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsReLUAuxNeeded =
|
||||
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
|
||||
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
|
||||
constexpr bool IsClamp =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
|
||||
|
||||
constexpr bool IsBackpropFusion =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
|
||||
|
||||
// Input related converter
|
||||
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
|
||||
NumericConverter<ElementCompute, ElementC> source_converter;
|
||||
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
|
||||
|
||||
// Scale related converter
|
||||
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
||||
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
|
||||
|
||||
// Abs max converter
|
||||
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
|
||||
|
||||
// Output related converter
|
||||
NumericConverter<ElementD, ElementCompute> destination_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
|
||||
NumericConverter<ElementBias, ElementCompute> dBias_converter;
|
||||
|
||||
// Epilogue operations
|
||||
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
|
||||
multiplies<ElementCompute> mul;
|
||||
plus<ElementCompute> add;
|
||||
|
||||
// Activation operation
|
||||
ActivationFunctor activation;
|
||||
|
||||
// Bias binary operation
|
||||
BiasBinaryOp bias_op;
|
||||
|
||||
// Do conversion
|
||||
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
|
||||
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
|
||||
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
|
||||
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
|
||||
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
|
||||
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
|
||||
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
|
||||
|
||||
// Init local var
|
||||
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
|
||||
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
|
||||
|
||||
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
|
||||
converted_beta = mul(converted_beta, converted_scale_c);
|
||||
|
||||
ElementCompute inter_accum[kBlockM][kBlockN];
|
||||
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
ElementCompute local_dBias = ElementCompute(0);
|
||||
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
|
||||
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
|
||||
// per-row alpha
|
||||
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
|
||||
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
|
||||
}
|
||||
ElementCompute output = mul(converted_alpha, converted_acc);
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
|
||||
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
|
||||
output = bias_op(output, converted_bias);
|
||||
}
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.C.data())) {
|
||||
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
|
||||
// per-row beta
|
||||
if (epilogue_params.Vbeta.data()) {
|
||||
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
|
||||
}
|
||||
output = epilogue_fma(converted_beta, converted_src, output);
|
||||
}
|
||||
|
||||
if constexpr (IsBackpropFusion) {
|
||||
ElementAux aux_input = ElementAux(0);
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
|
||||
}
|
||||
|
||||
output = activation(output, aux_source_converter(aux_input));
|
||||
local_dBias = add(local_dBias, output);
|
||||
}
|
||||
else {
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
auto aux_output = output;
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
|
||||
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
|
||||
}
|
||||
|
||||
if constexpr (IsReLUAuxNeeded) {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
|
||||
} else {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsClamp) { // Treat Clamp as ReLU
|
||||
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
|
||||
}
|
||||
else {
|
||||
output = activation(output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_output = amax_op(local_abs_max_output, output);
|
||||
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
|
||||
}
|
||||
|
||||
inter_accum[m_b][n_b] = ElementCompute(output);
|
||||
}
|
||||
} // n_b
|
||||
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
|
||||
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
|
||||
local_dBias = add(local_dBias, converted_dBias);
|
||||
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
|
||||
}
|
||||
}
|
||||
} // m_b
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp critical(Abs_Max_Data_Update)
|
||||
#endif
|
||||
{
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_D) {
|
||||
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_Aux) {
|
||||
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
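// Summary sketch (added for clarity, not part of the original file): per
// output element the epilogue above computes
//   out = activation(alpha * scale_a * scale_b * acc + bias + beta * scale_c * C)
//   D   = convert(scale_d * out)
// with abs-max reductions recorded when D (or Aux) is an FP8 type.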
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM - General Matrix-Matrix contraction without conjugation options
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gemm3x(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) "
|
||||
"with Batchmode are supported");
|
||||
// Lower the Matrix-Multiplication with Blockwise scaling (Gemm3x) to a Tensor Contraction (Gett).
|
||||
Gett(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // cutlass::reference::host
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,511 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GETT in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/relatively_equal.h"
|
||||
#include <iostream>
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::reference::host {
|
||||
|
||||
template<class T, class = void>
|
||||
struct ElementTraits {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
|
||||
using type = decltype(std::declval<T>().get());
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAccumulator_,
|
||||
class TensorA_, // (M, K, L)
|
||||
class TensorB_, // (N, K, L)
|
||||
class TensorScaleA_, // (m, k, L)
|
||||
class TensorScaleB_, // (n, k, L)
|
||||
class TileShape_
|
||||
>
|
||||
struct GettMainloopParams {
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using TensorA = TensorA_;
|
||||
using TensorB = TensorB_;
|
||||
using EngineA = typename TensorA::engine_type;
|
||||
using LayoutA = typename TensorA::layout_type;
|
||||
using EngineB = typename TensorB::engine_type;
|
||||
using LayoutB = typename TensorB::layout_type;
|
||||
|
||||
using TensorScaleA = TensorScaleA_;
|
||||
using TensorScaleB = TensorScaleB_;
|
||||
using TileShape = TileShape_;
|
||||
using EngineScaleA = typename TensorScaleA::engine_type;
|
||||
using EngineScaleB = typename TensorScaleB::engine_type;
|
||||
|
||||
TensorA A{};
|
||||
TensorB B{};
|
||||
TensorScaleA ScaleA{};
|
||||
TensorScaleB ScaleB{};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementScalingFactor_,
|
||||
class ElementAccumulator_,
|
||||
class ElementCompute_,
|
||||
class TensorC_, // (M, N, L)
|
||||
class TensorD_, // (M, N, L)
|
||||
class VectorBias_ = TensorD_, // (M, 1)
|
||||
class TensorAux_ = TensorD_, // (M, N, L)
|
||||
class VectorAlpha_ = TensorD_, // (M, 1)
|
||||
class VectorBeta_ = VectorAlpha_, // (M, 1)
|
||||
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
|
||||
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
|
||||
bool PerColumnBias_ = false
|
||||
>
|
||||
struct GettEpilogueParams {
|
||||
using ElementScalar = ElementScalar_;
|
||||
using ElementScalingFactor = ElementScalingFactor_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using TensorC = TensorC_;
|
||||
using TensorD = TensorD_;
|
||||
using TensorAux = TensorAux_;
|
||||
using VectorBias = VectorBias_;
|
||||
using VectorAlpha = VectorAlpha_;
|
||||
using VectorBeta = VectorBeta_;
|
||||
using ActivationFunctor = ActivationFunctor_;
|
||||
using BiasBinaryOp = BiasBinaryOp_;
|
||||
|
||||
using EngineC = typename TensorC::engine_type;
|
||||
using LayoutC = typename TensorC::layout_type;
|
||||
using EngineD = typename TensorD::engine_type;
|
||||
using LayoutD = typename TensorD::layout_type;
|
||||
static constexpr bool PerColumnBias = PerColumnBias_;
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
|
||||
TensorC C{};
|
||||
TensorD D{};
|
||||
VectorBias Bias{};
|
||||
TensorAux Aux{};
|
||||
VectorAlpha Valpha{};
|
||||
VectorBeta Vbeta{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
|
||||
ElementAccumulator* abs_max_D = nullptr;
|
||||
ElementAccumulator* abs_max_Aux = nullptr;
|
||||
|
||||
ElementScalingFactor scale_a = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_b = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_c = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_d = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_aux = ElementScalingFactor(1);
|
||||
|
||||
bool beta_per_channel_scaling = false;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gett(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
|
||||
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
|
||||
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
|
||||
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
|
||||
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel for collapse(3)
|
||||
#endif
|
||||
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
|
||||
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
|
||||
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
|
||||
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
|
||||
gett_mainloop(mainloop_params, m, n, l, acc);
|
||||
gett_epilogue(epilogue_params, m, n, l, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Mainloop
|
||||
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_mainloop(
|
||||
MainloopParams const& mainloop_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
||||
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
||||
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
|
||||
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
|
||||
|
||||
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
|
||||
RingOp fma_op;
|
||||
|
||||
multiplies<ElementAccumulator> scale_op;
|
||||
|
||||
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});
|
||||
|
||||
// Temporary accumulators to separate the blockwise accumulation
|
||||
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
|
||||
|
||||
// Zero out accumulators
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t block_m = m / kBlockM;
|
||||
int64_t block_n = n / kBlockN;
|
||||
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
|
||||
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l);
|
||||
|
||||
const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
|
||||
const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape());
|
||||
assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
|
||||
assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape()));
|
||||
|
||||
// Compute on this k-block
|
||||
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
|
||||
|
||||
// Load the blockwise scaling factors for A and B from the blockscale tensors
|
||||
int64_t block_k = k / kBlockK;
|
||||
cute::Tensor scale_a = blockscale_A(_, block_k);
|
||||
cute::Tensor scale_b = blockscale_B(_, block_k);
|
||||
|
||||
// Load A
|
||||
ElementAccumulator a_frag[kBlockM];
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
||||
} else {
|
||||
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// Load B
|
||||
ElementAccumulator b_frag[kBlockN];
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
|
||||
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
||||
} else {
|
||||
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// do compute
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Groupwise-scaling at the kBlockK boundary:
// (a) Apply the group and block scaling factors to the partial accumulated results (acc_temp)
// (b) Zero out the partial temporary accumulators (acc_temp)
// (c) Update the permanent accumulators (acc)
if ((k+1) % kBlockK == 0) {
for (int m_b = 0; m_b < kBlockM; ++m_b) {
auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
}
|
||||
|
||||
}
|
||||
}
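
The mainloop above accumulates unscaled partial products per K-block and only folds them into the final accumulator, multiplied by the per-group scale factors of A and B, at each kBlockK boundary. The following NumPy sketch is a minimal model of that same computation; it is illustrative only (the function and variable names are not part of the CUTLASS reference API), and it assumes M, N, K divide evenly by the granularities.

```python
# Illustrative NumPy model of the groupwise-scaled accumulation above (not part of CUTLASS).
import numpy as np

def groupwise_scaled_gemm(A, B, scale_A, scale_B, block_k, gran_m, gran_n):
    """A: (M, K), B: (N, K); scale_A: (M // gran_m, K // block_k);
    scale_B: (N // gran_n, K // block_k). Returns the (M, N) accumulator."""
    M, K = A.shape
    N, _ = B.shape
    acc = np.zeros((M, N), dtype=np.float64)
    for k0 in range(0, K, block_k):
        # Unscaled partial accumulation over one K-block (acc_temp in the reference).
        acc_temp = A[:, k0:k0 + block_k].astype(np.float64) @ B[:, k0:k0 + block_k].T.astype(np.float64)
        # Broadcast the per-group scales of this K-block over their rows / columns.
        sa = np.repeat(scale_A[:, k0 // block_k], gran_m)[:, None]   # (M, 1)
        sb = np.repeat(scale_B[:, k0 // block_k], gran_n)[None, :]   # (1, N)
        acc += acc_temp * sa * sb
    return acc
```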
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Epilogue
|
||||
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_epilogue(
|
||||
EpilogueParams const& epilogue_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, N, B");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "M, N, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementCompute = typename EpilogueParams::ElementCompute;
|
||||
using ElementC = typename EpilogueParams::TensorC::value_type;
|
||||
using ElementD = typename EpilogueParams::TensorD::value_type;
|
||||
using ElementAux = typename EpilogueParams::TensorAux::value_type;
|
||||
using ElementBias = typename EpilogueParams::VectorBias::value_type;
|
||||
using ElementScalar = typename EpilogueParams::ElementScalar;
|
||||
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
|
||||
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsReLUAuxNeeded =
|
||||
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
|
||||
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
|
||||
constexpr bool IsClamp =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
|
||||
|
||||
constexpr bool IsBackpropFusion =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
|
||||
|
||||
// Input related converter
|
||||
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
|
||||
NumericConverter<ElementCompute, ElementC> source_converter;
|
||||
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
|
||||
|
||||
// Scale related converter
|
||||
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
||||
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
|
||||
|
||||
// Abs max converter
|
||||
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
|
||||
|
||||
// Output related converter
|
||||
NumericConverter<ElementD, ElementCompute> destination_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
|
||||
NumericConverter<ElementBias, ElementCompute> dBias_converter;
|
||||
|
||||
// Epilogue operations
|
||||
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
|
||||
multiplies<ElementCompute> mul;
|
||||
plus<ElementCompute> add;
|
||||
|
||||
// Activation operation
|
||||
ActivationFunctor activation;
|
||||
|
||||
// Bias binary operation
|
||||
BiasBinaryOp bias_op;
|
||||
|
||||
// Do conversion
|
||||
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
|
||||
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
|
||||
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
|
||||
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
|
||||
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
|
||||
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
|
||||
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
|
||||
|
||||
// Init local var
|
||||
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
|
||||
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
|
||||
|
||||
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
|
||||
converted_beta = mul(converted_beta, converted_scale_c);
|
||||
|
||||
ElementCompute inter_accum[kBlockM][kBlockN];
|
||||
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
ElementCompute local_dBias = ElementCompute(0);
|
||||
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
|
||||
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
|
||||
// per-row alpha
|
||||
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
|
||||
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
|
||||
}
|
||||
ElementCompute output = mul(converted_alpha, converted_acc);
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
|
||||
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
|
||||
output = bias_op(output, converted_bias);
|
||||
}
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.C.data())) {
|
||||
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
|
||||
// per-row beta
|
||||
if (epilogue_params.Vbeta.data()) {
|
||||
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
|
||||
}
|
||||
output = epilogue_fma(converted_beta, converted_src, output);
|
||||
}
|
||||
|
||||
if constexpr (IsBackpropFusion) {
|
||||
ElementAux aux_input = ElementAux(0);
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
|
||||
}
|
||||
|
||||
output = activation(output, aux_source_converter(aux_input));
|
||||
local_dBias = add(local_dBias, output);
|
||||
}
|
||||
else {
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
auto aux_output = output;
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
|
||||
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
|
||||
}
|
||||
|
||||
if constexpr (IsReLUAuxNeeded) {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
|
||||
} else {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsClamp) { // Treat Clamp as ReLU
|
||||
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
|
||||
}
|
||||
else {
|
||||
output = activation(output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_output = amax_op(local_abs_max_output, output);
|
||||
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
|
||||
}
|
||||
|
||||
inter_accum[m_b][n_b] = ElementCompute(output);
|
||||
}
|
||||
} // n_b
|
||||
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
|
||||
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
|
||||
local_dBias = add(local_dBias, converted_dBias);
|
||||
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
|
||||
}
|
||||
}
|
||||
} // m_b
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp critical(Abs_Max_Data_Update)
|
||||
#endif
|
||||
{
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_D) {
|
||||
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_Aux) {
|
||||
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
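
Ignoring the auxiliary-tensor and backprop-fusion branches, the per-element computation performed by the epilogue above can be summarized as follows, with $\phi$ the activation functor, $s_a, s_b, s_c, s_d$ the tensorwise scale factors, and $\alpha$, $\beta$ optionally overridden per row by Valpha / Vbeta (the bias is indexed by the column instead of the row when PerColumnBias is set):

$$
D_{mnl} = \mathrm{convert}_D\Big( s_d \cdot \phi\big( \alpha\, s_a s_b\, \mathrm{acc}_{mn} + \mathrm{bias}_m + \beta\, s_c\, C_{mnl} \big) \Big),
\qquad
\mathrm{absmax}_D = \max_{m,n,l} \big| \phi(\cdot) \big|
$$

The $s_d$ multiplication and the absmax reduction are only applied when ElementD is an FP8 type.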
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM - General Matrix-Matrix contraction without conjugation options
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gemm3x(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) "
|
||||
"with Batchmode are supported");
|
||||
// Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett).
|
||||
Gett(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // cutlass::reference::host
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,69 @@
|
||||
import os
|
||||
import setuptools
|
||||
import shutil
|
||||
import subprocess
|
||||
from setuptools.command.develop import develop
|
||||
from setuptools.command.install import install
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
jit_include_dirs = ('deep_gemm/include/deep_gemm', )
|
||||
cutlass_dirs = '../../include'
|
||||
third_party_include_dirs = (os.path.join(cutlass_dirs, 'cute'), os.path.join(cutlass_dirs, 'cutlass'))
|
||||
print(third_party_include_dirs)
|
||||
|
||||
|
||||
class PostDevelopCommand(develop):
|
||||
def run(self):
|
||||
develop.run(self)
|
||||
self.make_jit_include_symlinks()
|
||||
|
||||
@staticmethod
|
||||
def make_jit_include_symlinks():
|
||||
# Make symbolic links of third-party include directories
|
||||
for d in third_party_include_dirs:
|
||||
dirname = d.split('/')[-1]
|
||||
src_dir = f'{current_dir}/{d}'
|
||||
dst_dir = f'{current_dir}/deep_gemm/include/{dirname}'
|
||||
if not os.path.exists(src_dir):
|
||||
os.makedirs(src_dir, exist_ok=True)
|
||||
assert os.path.exists(src_dir)
|
||||
if os.path.exists(dst_dir):
|
||||
assert os.path.islink(dst_dir)
|
||||
os.unlink(dst_dir)
|
||||
os.symlink(src_dir, dst_dir, target_is_directory=True)
|
||||
|
||||
|
||||
class PostInstallCommand(install):
|
||||
def run(self):
|
||||
install.run(self)
|
||||
self.copy_jit_includes()
|
||||
|
||||
def copy_jit_includes(self):
|
||||
# Copy include directories needed by JIT
|
||||
shutil.rmtree(f'{self.build_lib}/deep_gemm/include', ignore_errors=True)
|
||||
os.makedirs(f'{self.build_lib}/deep_gemm/include', exist_ok=False)
|
||||
for d in jit_include_dirs + third_party_include_dirs:
|
||||
src_dir = f'{current_dir}/{d}'
|
||||
dst_dir = f'{self.build_lib}/deep_gemm/include/{d.split("/")[-1]}'
|
||||
assert os.path.exists(src_dir)
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cmd = ['git', 'rev-parse', '--short', 'HEAD']
|
||||
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
|
||||
except:
|
||||
revision = ''
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
setuptools.setup(
|
||||
name='deep_gemm',
|
||||
version='1.0.0' + revision,
|
||||
packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'],
|
||||
cmdclass={
|
||||
'develop': PostDevelopCommand,
|
||||
'install': PostInstallCommand
|
||||
}
|
||||
)
|
||||
@ -0,0 +1,158 @@
|
||||
import random
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), dtype=x.dtype, device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
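
As a quick illustration of the two casting helpers above, here are the shapes they produce for a 4096x7168 bfloat16 tensor. This is a usage sketch assuming a CUDA device is available; the values 56 and 32 are simply 7168/128 and 4096/128.

```python
# Usage sketch for the FP8 casting helpers defined above.
x = torch.randn((4096, 7168), device='cuda', dtype=torch.bfloat16)

x_fp8, x_scales = per_token_cast_to_fp8(x)
# x_fp8:    (4096, 7168) torch.float8_e4m3fn, one scale per 1x128 slice of a row
# x_scales: (4096, 56)   torch.float32

w_fp8, w_scales = per_block_cast_to_fp8(x)
# w_fp8:    (4096, 7168) torch.float8_e4m3fn, one scale per 128x128 block
# w_scales: (32, 56)     torch.float32
```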
|
||||
|
||||
|
||||
def construct(m: int, k: int, n: int) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_out = x @ y.t()
|
||||
|
||||
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||
out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_out = torch.einsum('gmk,gnk->gmn', x, y)
|
||||
|
||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
||||
for i in range(num_groups):
|
||||
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||
|
||||
# For non-masked input, we must merge the group and M dims
|
||||
if not is_masked:
|
||||
x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
|
||||
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for m in (64, 128, 4096):
|
||||
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
|
||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing grouped contiguous GEMM:')
|
||||
|
||||
for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
|
||||
# TODO: make a stronger test
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
||||
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
||||
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
||||
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
||||
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing grouped masked GEMM:')
|
||||
|
||||
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
||||
for k, n in ((7168, 4096), (2048, 7168), ):
|
||||
# Test correctness
|
||||
masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
|
||||
for i in range(10):
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = random.choice(masked_m_candidates)
|
||||
expected_m = int(masked_m.float().mean()) + 1
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
||||
masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
|
||||
|
||||
# Test performance with fixed shapes
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print('Library path:')
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
@ -0,0 +1,64 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Any
|
||||
|
||||
from deep_gemm import jit
|
||||
|
||||
|
||||
class Capture:
|
||||
def __init__(self) -> None:
|
||||
self.read_fd = None
|
||||
self.write_fd = None
|
||||
self.saved_stdout = None
|
||||
self.captured = None
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
self.read_fd, self.write_fd = os.pipe()
|
||||
self.saved_stdout = os.dup(1)
|
||||
os.dup2(self.write_fd, 1)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
os.dup2(self.saved_stdout, 1)
|
||||
os.close(self.write_fd)
|
||||
with os.fdopen(self.read_fd, 'r') as f:
|
||||
self.captured = f.read()
|
||||
|
||||
def capture(self) -> str:
|
||||
return self.captured
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Runtime
|
||||
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
|
||||
|
||||
# Templates
|
||||
print('Generated code:')
|
||||
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
|
||||
('enable_double_streams', bool), ('stream', torch.cuda.Stream))
|
||||
body = "\n"
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
|
||||
body += 'std::cout << enable_double_streams << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
|
||||
code = jit.generate((), args, body)
|
||||
print(code)
|
||||
|
||||
# Build
|
||||
print('Building ...')
|
||||
func = jit.build('test_func', args, code)
|
||||
|
||||
# Test correctness
|
||||
print('Running ...')
|
||||
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
|
||||
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
|
||||
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
|
||||
with Capture() as capture:
|
||||
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
|
||||
output = capture.capture()
|
||||
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
|
||||
assert output == ref_output, f'{output=}, {ref_output=}'
|
||||
|
||||
print('JIT test passed')
|
||||
12
examples/68_hopper_flash_mla/README.md
Normal file
@ -0,0 +1,12 @@
|
||||
# Hopper FlashMLA - Examples
The code in this example is migrated from [FlashMLA](https://github.com/deepseek-ai/FlashMLA/tree/main); it implements an efficient MLA decoding kernel for Hopper GPUs.

# Run the example
### Install
```
python setup.py install
```
### Run the test
```
python tests/test_flash_mla.py
```
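
Once installed, the extension exposes the two entry points bound in `csrc/flash_api.cpp` (`get_mla_metadata` and `fwd_kvcache_mla`). The sketch below shows how a call might look; the module name `flash_mla_cuda` and the tensor sizes are placeholders, not names defined by this README, so check the example's setup.py and tests for the actual wrapper.

```python
# Hypothetical usage sketch; module name and sizes are placeholders.
import torch
import flash_mla_cuda  # placeholder name for the compiled extension

b, s_q, h_q, h_kv, d, d_v = 4, 1, 128, 1, 576, 512
seqlens_k = torch.full((b,), 4096, dtype=torch.int32, device='cuda')

# Split-scheduling metadata is computed once per (seqlens_k, head configuration).
tile_scheduler_metadata, num_splits = flash_mla_cuda.get_mla_metadata(
    seqlens_k, s_q * h_q // h_kv, h_kv)

# q: (b, s_q, h_q, d), kcache: (num_blocks, page_block_size, h_kv, d),
# block_table: (b, max_num_blocks_per_seq)
# out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
#     q, kcache, None, d_v, seqlens_k, block_table,
#     softmax_scale, is_causal, tile_scheduler_metadata, num_splits)
```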
|
||||
213
examples/68_hopper_flash_mla/csrc/flash_api.cpp
Normal file
@ -0,0 +1,213 @@
|
||||
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
|
||||
|
||||
#include <torch/python.h>
|
||||
#include <torch/nn/functional.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cutlass/fast_math.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
std::vector<at::Tensor>
|
||||
get_mla_metadata(
|
||||
at::Tensor &seqlens_k,
|
||||
const int num_heads_per_head_k,
|
||||
const int num_heads_k
|
||||
) {
|
||||
// This should match the logic in the MLA kernel.
|
||||
static constexpr int block_size_m = 64;
|
||||
static constexpr int block_size_n = 64;
|
||||
static constexpr int fixed_overhead_num_blocks = 5;
|
||||
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
TORCH_CHECK(seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
|
||||
|
||||
int batch_size = seqlens_k.size(0);
|
||||
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
|
||||
auto options = seqlens_k.options();
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
int sm_count = dprops->multiProcessorCount;
|
||||
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
|
||||
|
||||
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
|
||||
auto num_splits = torch::empty({batch_size + 1}, options);
|
||||
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
int *num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Mla_metadata_params params = {};
|
||||
params.seqlens_k_ptr = seqlens_k_ptr;
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
|
||||
params.num_splits_ptr = num_splits_ptr;
|
||||
params.batch_size = batch_size;
|
||||
params.block_size_n = block_size_n;
|
||||
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
|
||||
params.num_sm_parts = num_sm_parts;
|
||||
get_mla_metadata_func(params, stream);
|
||||
|
||||
return {tile_scheduler_metadata, num_splits};
|
||||
}
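
To make the split scheduling concrete, the snippet below repeats the `num_sm_parts` arithmetic from `get_mla_metadata` in Python for an illustrative configuration. The SM count of 132 is an assumption for the example (H100 SXM-like); the function itself queries it from the device properties.

```python
# Mirrors the integer arithmetic in get_mla_metadata() above (illustrative values only).
def ceil_div(a, b):
    return (a + b - 1) // b

sm_count = 132              # assumed here; queried via cudaDeviceProp in the C++ code
num_heads_k = 1             # MLA decoding typically uses a single latent KV head
num_heads_per_head_k = 128
block_size_m = 64

num_sm_parts = sm_count // num_heads_k // ceil_div(num_heads_per_head_k, block_size_m)
print(num_sm_parts)         # 132 // 1 // 2 == 66
```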
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd_kvcache_mla(
|
||||
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
|
||||
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
||||
const int head_size_v,
|
||||
const at::Tensor &seqlens_k, // batch_size
|
||||
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||
const at::Tensor &num_splits // batch_size + 1
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90);
|
||||
|
||||
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
CHECK_DEVICE(block_table);
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
||||
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q_ori = sizes[1];
|
||||
const int num_heads_ori = sizes[2];
|
||||
const int head_size = sizes[3];
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
|
||||
|
||||
const int max_num_blocks_per_seq = block_table.size(1);
|
||||
const int num_blocks = kcache.size(0);
|
||||
const int page_block_size = kcache.size(1);
|
||||
const int num_heads_k = kcache.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (seqlen_q_ori == 1) { is_causal = false; }
|
||||
|
||||
const int ngroups = num_heads_ori / num_heads_k;
|
||||
const int seqlen_q = seqlen_q_ori * ngroups;
|
||||
const int num_heads = num_heads_k;
|
||||
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q, num_heads, head_size});
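// The view/transpose/reshape above folds the query head groups into the query-length
// dimension so the kernel only sees num_heads_k heads. A standalone PyTorch sketch of the
// same transformation (with made-up sizes; not part of this file) is shown below.

```python
# Equivalent of the head-group folding done above (illustrative sizes).
import torch

batch_size, seqlen_q_ori, num_heads_ori, head_size = 2, 1, 128, 576
num_heads_k = 1
ngroups = num_heads_ori // num_heads_k
q = torch.randn(batch_size, seqlen_q_ori, num_heads_ori, head_size)

q_folded = (q.view(batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size)
              .transpose(2, 3)
              .reshape(batch_size, seqlen_q_ori * ngroups, num_heads_k, head_size))
print(q_folded.shape)  # torch.Size([2, 128, 1, 576])
```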
|
||||
|
||||
int head_size_k = head_size;
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
|
||||
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
|
||||
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
CHECK_CONTIGUOUS(seqlens_k);
|
||||
CHECK_SHAPE(seqlens_k, batch_size);
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
|
||||
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
Flash_fwd_mla_params params = {};
|
||||
// Set the sizes.
|
||||
params.b = batch_size;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
|
||||
params.h = num_heads;
|
||||
params.h_h_k_ratio = num_heads / num_heads_k;
|
||||
params.ngroups = ngroups;
|
||||
params.is_causal = is_causal;
|
||||
params.d = head_size;
|
||||
params.d_v = head_size_v;
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = kcache.data_ptr();
|
||||
params.v_ptr = vcache.data_ptr();
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
||||
// All strides are in elements, not bytes.
|
||||
params.q_batch_stride = q.stride(0);
|
||||
params.k_batch_stride = kcache.stride(0);
|
||||
params.v_batch_stride = vcache.stride(0);
|
||||
params.o_batch_stride = out.stride(0);
|
||||
params.q_row_stride = q.stride(-3);
|
||||
params.k_row_stride = kcache.stride(-3);
|
||||
params.v_row_stride = vcache.stride(-3);
|
||||
params.o_row_stride = out.stride(-3);
|
||||
params.q_head_stride = q.stride(-2);
|
||||
params.k_head_stride = kcache.stride(-2);
|
||||
params.v_head_stride = vcache.stride(-2);
|
||||
params.o_head_stride = out.stride(-2);
|
||||
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
||||
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
||||
CHECK_DEVICE(tile_scheduler_metadata);
|
||||
CHECK_CONTIGUOUS(tile_scheduler_metadata);
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
params.num_sm_parts = tile_scheduler_metadata.size(0);
|
||||
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
|
||||
CHECK_DEVICE(num_splits);
|
||||
CHECK_CONTIGUOUS(num_splits);
|
||||
params.num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
|
||||
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
||||
params.oaccum_ptr = out_accum.data_ptr();
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(head_size == 576);
|
||||
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
||||
}
|
||||
#ifndef FLASH_MLA_DISABLE_FP16
|
||||
else if (q_dtype == torch::kHalf) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
||||
}
|
||||
|
||||
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
||||
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
|
||||
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
|
||||
|
||||
return {out, softmax_lse};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "FlashMLA";
|
||||
m.def("get_mla_metadata", &get_mla_metadata);
|
||||
m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
|
||||
}
|
||||
@ -0,0 +1,3 @@
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
@ -0,0 +1,3 @@
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
||||
603
examples/68_hopper_flash_mla/csrc/flash_fwd_mla_kernel.h
Normal file
@ -0,0 +1,603 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#include "named_barrier.h"
|
||||
#include "utils.h"
|
||||
#include "softmax.h"
|
||||
#include "static_switch.h"
|
||||
#include "flash_mla.h"
|
||||
|
||||
|
||||
template<typename PrecType, int DIM, int DIM2 = DIM>
|
||||
constexpr auto getSmemLayoutK() {
|
||||
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
|
||||
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
|
||||
|
||||
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
|
||||
return GMMA::Layout_K_SW128_Atom<PrecType>{};
|
||||
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
|
||||
return GMMA::Layout_K_SW64_Atom<PrecType>{};
|
||||
} else {
|
||||
return GMMA::Layout_K_SW32_Atom<PrecType>{};
|
||||
}
|
||||
}
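
The branch above keys purely off the byte widths of the two head dimensions. The small Python model below reproduces that decision; the returned strings name the CuTe GMMA layout atoms used in the C++ code, and the example values correspond to the bf16, 576/512 head-dimension configuration used by this kernel.

```python
# Mirrors the compile-time branch in getSmemLayoutK() above.
def smem_layout_k_atom(elem_bytes: int, dim: int, dim2: int = None) -> str:
    dim2 = dim if dim2 is None else dim2
    head_bytes, head_bytes2 = elem_bytes * dim, elem_bytes * dim2
    if head_bytes % 128 == 0 and head_bytes2 % 128 == 0:
        return "GMMA::Layout_K_SW128_Atom"
    if head_bytes % 64 == 0 and head_bytes2 % 64 == 0:
        return "GMMA::Layout_K_SW64_Atom"
    return "GMMA::Layout_K_SW32_Atom"

# bf16 (2 bytes), kHeadDim = 576, kHeadDimV = 512 -> both multiples of 128 bytes -> SW128 swizzle.
print(smem_layout_k_atom(2, 576, 512))
```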
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
struct Flash_fwd_kernel_traits_mla {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using index_t = int64_t;
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
static constexpr int kNWarpsS = 4;
|
||||
static constexpr int kNThreadsS = kNWarpsS * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
|
||||
static_assert(kHeadDimV % 32 == 0);
|
||||
static_assert(kHeadDimV <= kHeadDim);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = decltype(make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
|
||||
GMMA::Major::K, GMMA::Major::K>(),
|
||||
Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
|
||||
|
||||
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
|
||||
using TiledMmaO = decltype(make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
|
||||
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim>(),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutK = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutV = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
||||
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
|
||||
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
||||
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
||||
|
||||
using SmemLayoutAtomO = decltype(composition(
|
||||
Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
|
||||
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopy = decltype(make_tiled_copy(
|
||||
Copy_Atom<Gmem_copy_struct, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using GmemLayoutAtomO = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopyO = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
GmemLayoutAtomO{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
||||
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
|
||||
using GmemLayoutAtomOaccum = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
|
||||
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
|
||||
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
|
||||
};
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct SharedStorageMLA {
|
||||
union {
|
||||
struct {
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
||||
};
|
||||
struct {
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
|
||||
__forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
|
||||
SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
|
||||
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
||||
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
||||
|
||||
// Epilogue
|
||||
|
||||
const int split_offset = __ldg(params.num_splits_ptr + bidb);
|
||||
|
||||
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
|
||||
|
||||
    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    using SmemTiledCopyO = std::conditional_t<
        !Split,
        typename Kernel_traits::SmemCopyAtomO,
        typename Kernel_traits::SmemCopyAtomOaccum
    >;
    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor rO = flash::convert_type<ElementO>(tOrO);
    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);          // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);  // ((Atom,AtomNum),PIPE_M,PIPE_N)

    __syncthreads();

    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);

    const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
    const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;

    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                 Shape<Int<kBlockM>, Int<kHeadDimV>>{},
                                 make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
                                   Shape<Int<kBlockM>>{}, Stride<_1>{});

    using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
    GmemTiledCopyO gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);

    __syncthreads();

    if (tidx >= kNThreadsS) { return; }

    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);

    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma_o.partition_C(caccO);                               // ((MMA=4, X), MMA_M, MMA_K=1)
    Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                        // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);                                // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
    );
}
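The split path above writes partial outputs (`Oaccum`) and partial LSE values at row offsets derived from `(split_offset + n_split_idx, head, query row)`. A minimal Python sketch of that same index arithmetic, with purely illustrative numbers (the helper names are not part of the source):

```python
# Illustrative sketch of the split-accumulator addressing used in store<>():
# Oaccum behaves like [total_splits, num_heads, seqlen_q, d_v] flattened to 1D,
# LSEaccum like [total_splits, num_heads, seqlen_q].

def oaccum_row_offset(split_offset, n_split_idx, h, seqlen_q, d_v, bidh, m_block, kBlockM):
    # Matches: (((split_offset + n_split_idx) * h + bidh) * seqlen_q + m_block * kBlockM) * d_v
    return (((split_offset + n_split_idx) * h + bidh) * seqlen_q + m_block * kBlockM) * d_v

def lseaccum_row_offset(split_offset, n_split_idx, h, seqlen_q, bidh, m_block, kBlockM):
    # Matches: ((split_offset + n_split_idx) * h + bidh) * seqlen_q + m_block * kBlockM
    return ((split_offset + n_split_idx) * h + bidh) * seqlen_q + m_block * kBlockM

# Example: this batch's splits start at split_offset=3, this CTA handles n_split_idx=1,
# head 2 of 16, query block 0 with kBlockM=64, seqlen_q=64, d_v=512.
print(oaccum_row_offset(3, 1, 16, 64, 512, 2, 0, 64))   # 2162688
print(lseaccum_row_offset(3, 1, 16, 64, 2, 0, 64))      # 4224
```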
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
                                                                   const int bidb, const int bidh, const int m_block,
                                                                   const int n_split_idx, const int seqlen_k,
                                                                   const int n_block_min, const int n_block_max, const bool NoSplit,
                                                                   SharedStorage &shared_storage) {
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
    constexpr int kNThreads = Kernel_traits::kNThreads;
    constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
    static_assert(kNThreads == 256 and kNThreadsS == 128);
    using Element = typename Kernel_traits::Element;
    using index_t = typename Kernel_traits::index_t;

    const int tidx = threadIdx.x;
    int n_block = n_block_max - 1;

    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});

    Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
    Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
    Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
    Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
    Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);

    typename Kernel_traits::TiledMmaO tiled_mma_o;
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
    Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt);                                      // (MMA, MMA_K,MMA_N)
    Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});  // ((MMA=4, X), MMA_M, MMA_N=1)
    clear(tOrO);

    flash::Softmax<2 * size<1>(tOrO)> softmax;

    int warp_group_idx = cutlass::canonical_warp_group_idx();
    if (warp_group_idx == 0) {
        typename Kernel_traits::TiledMma tiled_mma;
        auto thr_mma = tiled_mma.get_thread_slice(tidx);
        Tensor tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
        Tensor tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)

        if (n_block % 2 == 1) {
            // Double buffer for sK
            constexpr int sK_offset = size(sK);
            tSrK.data() = tSrK.data() + sK_offset / 8;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }

        // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
        // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
        // We will have at least 1 "masking" iteration.
        // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
        // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
#pragma unroll 1
        for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
            __syncthreads();

            Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // ((MMA=4, X), MMA_M, MMA_N=1)
            flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);

            const bool is_masking_step = masking_step > 0;
            const bool is_first_masking_step = masking_step == n_masking_steps;

            if (is_masking_step) {
                Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
                Tensor tScS = thr_mma.partition_C(cS);
#pragma unroll
                for (int i = 0; i < size(tSrS); ++i) {
                    if constexpr (!Is_causal) {  // Just masking based on col
                        if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
                    } else {
                        // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                        // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                        int row = int(get<0>(tScS(i)));
                        int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
                        if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
                    }
                }
            }

            // We have key_padding_mask so we'll need to Check_inf
            Tensor scale_o = is_first_masking_step
                ? softmax.template softmax</*Is_first=*/true,  /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
                : is_masking_step ?
                    softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
                    : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);

            Tensor rP = flash::convert_type<Element>(tSrS);
            cute::copy(rP, tPsP);
            cute::copy(scale_o, tScale_osScale_o);

            cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));

            flash::rescale_o(tOrO, scale_o);

            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

            // Double buffer for sK
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
            tSrK.data() = tSrK.data() + sK_offset / 8;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }

        cute::copy(softmax.row_max, tRow_maxsRow_max);
        cute::copy(softmax.row_sum, tRow_sumsRow_sum);
        cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
    } else {
        const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
        int cur_block_table = __ldg(&block_table[n_block]);

        const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
        Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                make_stride(params.q_row_stride, _1{}));
        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
        auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
        Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
        Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));

        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
                                                              params.seqlen_q - m_block * kBlockM);

        const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
        Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                                Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                make_stride(params.k_row_stride, _1{}));
        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
        auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
        Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
        Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
        Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        Tensor tKcK = gmem_thr_copy_K.partition_S(cK);                           // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
        Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));

        if (n_block % 2 == 1) {
            // Double buffer for sK
            constexpr int sK_offset = size(sK);
            tKsK.data() = tKsK.data() + sK_offset;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }

        // We need to clear the sK smem tiles because K is V.
        const index_t offset_k = cur_block_table * params.k_batch_stride;
        tKgK.data() = tKgK.data() + offset_k;
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
                                                                                     seqlen_k - n_block * kBlockN);
        tKgK.data() = tKgK.data() + -offset_k;
        cute::cp_async_fence();

        if (n_block - 1 >= n_block_min) {
            cur_block_table = __ldg(&block_table[n_block - 1]);
        }

#pragma unroll 1
        for (; n_block >= n_block_min; --n_block) {
            flash::cp_async_wait<0>();
            __syncthreads();

            if (n_block - 1 >= n_block_min) {
                // Double buffer for sK
                const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
                tKsK.data() = tKsK.data() + sK_offset;

                const index_t offset_k = cur_block_table * params.k_batch_stride;
                tKgK.data() = tKgK.data() + offset_k;
                flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
                tKgK.data() = tKgK.data() + -offset_k;
                cute::cp_async_fence();
            }

            cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));

            if (n_block - 2 >= n_block_min) {
                cur_block_table = __ldg(&block_table[n_block - 2]);
            }

            typename Kernel_traits::TiledMma tiled_mma;
            auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
            Tensor rP = make_tensor<Element>(tSrS_layout);
            Tensor scale_o = make_tensor<float>(Shape<_2>{});
            cute::copy(tScale_osScale_o, scale_o);
            cute::copy(tPsP, rP);

            flash::rescale_o(tOrO, scale_o);

            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

            // Double buffer for sK
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }

        cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
        cute::copy(tRow_maxsRow_max, softmax.row_max);
        cute::copy(tRow_sumsRow_sum, softmax.row_sum);
    }

    if (NoSplit)
        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
    else
        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
}
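The causal branch of the masking step keeps an element at `(row, col)` only if `col <= col_limit_right`, following the inequality in the comment (query rows pack `ngroups` query heads per K/V head, hence the division). A small Python check of that formula, with illustrative sizes (the helper name is not part of the source):

```python
# Illustrative check of the causal column limit used in the masking step above.
def col_limit_right(seqlen_k, seqlen_q, ngroups, n_block, kBlockN, m_block, kBlockM, row):
    return seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) // ngroups

# Example: seqlen_k=300, folded seqlen_q=128 with ngroups=2, K block n_block=3 of size 64,
# first query block (m_block=0, kBlockM=64).
for row in (0, 63):
    print(row, col_limit_right(300, 128, 2, 3, 64, 0, 64, row))
# row 0  -> 300 - 1 - 192 - (127 // 2) = 44
# row 63 -> 300 - 1 - 192 - (64 // 2)  = 75
```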
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
    constexpr int kBlockN = Kernel_traits::kBlockN;
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);

    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
    int begin_idx = tile_scheduler_metadata.x;
    int begin_seqlen = tile_scheduler_metadata.y;
    int end_idx = tile_scheduler_metadata.z;
    int end_seqlen = tile_scheduler_metadata.w;
    if (begin_idx >= params.b) return;
    int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);

#pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
        const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
        const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
        const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
        const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
        if (batch_id > begin_idx) {
            __syncthreads();  // Barrier between two tiles.
        }
        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
    }
}
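Each CTA reads one metadata row `[begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx]` and walks the batches it was assigned. A minimal Python sketch of that loop, mirroring how `n_block_min`, `n_block_max`, and `NoSplit` are derived (the helper name and example numbers are illustrative):

```python
from math import ceil

def partition_work(meta, seqlens_k, kBlockN=64):
    begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx = meta[:5]
    work = []
    for batch_id in range(begin_idx, end_idx + 1):
        seqlen_k = seqlens_k[batch_id]
        n_split_idx = begin_n_split_idx if batch_id == begin_idx else 0
        n_block_min = begin_seqlen // kBlockN if batch_id == begin_idx else 0
        n_block_max = ceil(end_seqlen / kBlockN) if batch_id == end_idx else ceil(seqlen_k / kBlockN)
        no_split = n_block_min == 0 and n_block_max == ceil(seqlen_k / kBlockN)
        work.append((batch_id, n_split_idx, n_block_min, n_block_max, no_split))
    return work

# Example: this partition starts mid-way through batch 0 and finishes batch 1.
print(partition_work([0, 128, 1, 200, 1, 0, 0, 0], seqlens_k=[300, 200]))
# [(0, 1, 2, 5, False), (1, 0, 0, 4, True)]
```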
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
__global__ void __launch_bounds__(256, 1, 1)
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
    constexpr int kNThreads = 128;

    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;
    const int hs = params.h * params.seqlen_q;
    const int batch_idx = bidx / hs;
    const int hs_idx = bidx % hs;

    const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
    const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
    FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
    if (actual_num_splits == 1) return;

    __shared__ ElementAccum sLseScale[kMaxSplits];

    const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
    const index_t row_offset_lse = bidx;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
                                   Shape<Int<kMaxSplits>>{}, make_stride(hs));
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<_1>{}, Stride<_1>{});

    int warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == 0) {
        constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);

        float local_lse[kNLsePerThread];
        for (int i = 0; i < kNLsePerThread; ++i) {
            const int split = i * 32 + tidx;
            local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
        }

        float max_lse = -INFINITY;
        for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
        for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
        max_lse = max_lse == -INFINITY ? 0.0f : max_lse;  // In case all local LSEs are -inf

        float sum_lse = 0;
        for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
        for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);

        float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
        if (tidx == 0) gLSE(0) = global_lse;

        for (int i = 0; i < kNLsePerThread; ++i) {
            const int split = i * 32 + tidx;
            if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
        }
    }
    __syncthreads();

    static_assert(kHeadDimV % kNThreads == 0);
    constexpr int Elements = kHeadDimV / kNThreads;
    const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kHeadDimV>>{}, Stride<_1>{});
    using GmemTiledCopyOaccum = decltype(make_tiled_copy(
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
        Layout<Shape<Int<kNThreads>>>{},
        Layout<Shape<Int<Elements>>>{}));
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);

    for (int split = 0; split < actual_num_splits; ++split) {
        cute::copy(tOgOaccum, tOrOaccum);
        ElementAccum lse_scale = sLseScale[split];
        for (int i = 0; i < size(tOrO); ++i) {
            tOrO(i) += lse_scale * tOrOaccum(i);
        }
        tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
    }

    Tensor rO = flash::convert_type<Element>(tOrO);
    const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
    const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
    auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
    cute::copy(rO, gO);
}
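The combine kernel merges the per-split LSE values with a numerically stable log-sum-exp and then sums the partial outputs weighted by `exp(lse_i - lse_global)`. A NumPy sketch of the same math for a single (batch, head, query-row) entry, assuming float32 inputs (the function name is illustrative):

```python
import numpy as np

def combine_splits(o_accum, lse_accum):
    # o_accum: (num_splits, head_dim_v) partial outputs; lse_accum: (num_splits,) partial LSEs.
    max_lse = lse_accum.max()
    max_lse = 0.0 if max_lse == -np.inf else max_lse          # all-(-inf) guard, as in the kernel
    sum_lse = np.exp(lse_accum - max_lse).sum()
    lse_global = np.log(sum_lse) + max_lse
    scales = np.exp(lse_accum - lse_global)                   # sLseScale in the kernel
    return (scales[:, None] * o_accum).sum(axis=0), lse_global

rng = np.random.default_rng(0)
o = rng.standard_normal((3, 8)).astype(np.float32)
lse = np.array([4.0, 2.5, 3.0], dtype=np.float32)
out, lse_g = combine_splits(o, lse)
print(out.shape, float(lse_g))   # (8,) ~4.46
```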
}  // namespace flash

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
    const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
        constexpr size_t smem_size = sizeof(SharedStorage);
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();

    dim3 grid_combine(params.b * params.h * params.seqlen_q);
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
        auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
            typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
        combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}

template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
    static_assert(Headdim == 576);
    FLASH_ASSERT(params.d_v == 512);
    FLASH_ASSERT(params.k_ptr == params.v_ptr);  // Shared_KV
    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
    run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}
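The host launcher runs two kernels back to back: the split-KV attention kernel over a `(num_m_block, h, num_sm_parts)` grid and the combine kernel over one block per output row. A back-of-the-envelope Python sketch of the launch geometry (all numbers illustrative):

```python
from math import ceil

def launch_dims(b, h, seqlen_q, num_sm_parts, kBlockM=64, kNThreads=256):
    num_m_block = ceil(seqlen_q / kBlockM)
    splitkv_grid = (num_m_block, h, num_sm_parts)   # flash_fwd_splitkv_mla_kernel
    combine_grid = (b * h * seqlen_q,)              # flash_fwd_splitkv_mla_combine_kernel, 128 threads each
    return splitkv_grid, kNThreads, combine_grid

# e.g. b=4, h=16 folded heads, seqlen_q=64, and a scheduler that chose 132 SM partitions.
print(launch_dims(4, 16, 64, 132))
# ((1, 16, 132), 256, (4096,))
```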

examples/68_hopper_flash_mla/csrc/flash_fwd_mla_metadata.cu (new file)
@@ -0,0 +1,77 @@
#include "flash_fwd_mla_kernel.h"

static constexpr int MaxBatchSize = 4096;

__global__ void __launch_bounds__(256, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
    int *seqlens_k_ptr = params.seqlens_k_ptr;
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
    int *num_splits_ptr = params.num_splits_ptr;
    int batch_size = params.batch_size;
    int block_size_n = params.block_size_n;
    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
    int num_sm_parts = params.num_sm_parts;

    __shared__ int num_blocks_shared[MaxBatchSize];
    __shared__ int num_splits_shared[MaxBatchSize];

    int total_num_blocks = 0;
    for (int i = threadIdx.x; i < batch_size; i += 32) {
        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
        num_blocks_shared[i] = num_blocks;
    }
    for (int offset = 16; offset >= 1; offset /= 2) {
        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
    }
    __syncwarp();

    if (threadIdx.x == 0) {
        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;

        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
        num_splits_shared[0] = 0;
        for (int i = 0; i < num_sm_parts; ++i) {
            int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
            tile_scheduler_metadata0[0] = now_idx;
            tile_scheduler_metadata0[1] = now_block * block_size_n;
            tile_scheduler_metadata1 = now_n_split_idx;
            int remain_payload = payload;
            while (now_idx < batch_size) {
                int num_blocks = num_blocks_shared[now_idx];
                int now_remain_blocks = num_blocks - now_block;
                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
                    cum_num_splits += now_n_split_idx + 1;
                    num_splits_shared[now_idx + 1] = cum_num_splits;
                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
                    ++now_idx;
                    now_block = 0;
                    now_n_split_idx = 0;
                } else {
                    if (remain_payload - fixed_overhead_num_blocks > 0) {
                        now_block += remain_payload - fixed_overhead_num_blocks;
                        ++now_n_split_idx;
                        remain_payload = 0;
                    }
                    break;
                }
            }
            tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
            tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
        }
        FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
    }
    __syncwarp();

    for (int i = threadIdx.x; i <= batch_size; i += 32) {
        num_splits_ptr[i] = num_splits_shared[i];
    }
}

void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.batch_size < MaxBatchSize);
    get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
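The single-thread section of the kernel greedily assigns roughly `total_blocks / num_sm_parts` K-blocks (plus a fixed per-batch overhead) to each SM partition; a batch that straddles two partitions becomes a split. A host-side Python sketch of the same greedy walk, with illustrative inputs (the function name is not part of the source):

```python
from math import ceil

def mla_metadata(seqlens_k, block_size_n, fixed_overhead_num_blocks, num_sm_parts):
    num_blocks = [ceil(s / block_size_n) for s in seqlens_k]
    total = sum(n + fixed_overhead_num_blocks for n in num_blocks)
    payload = ceil(total / num_sm_parts) + fixed_overhead_num_blocks

    metadata, num_splits = [], [0]
    idx = block = n_split_idx = cum_splits = 0
    for _ in range(num_sm_parts):
        begin_idx, begin_seqlen, begin_split = idx, block * block_size_n, n_split_idx
        remain = payload
        while idx < len(seqlens_k):
            remain_blocks = num_blocks[idx] - block
            if remain >= remain_blocks + fixed_overhead_num_blocks:
                cum_splits += n_split_idx + 1            # this batch is fully consumed here
                num_splits.append(cum_splits)
                remain -= remain_blocks + fixed_overhead_num_blocks
                idx, block, n_split_idx = idx + 1, 0, 0
            else:
                if remain - fixed_overhead_num_blocks > 0:
                    block += remain - fixed_overhead_num_blocks
                    n_split_idx += 1                     # this batch continues in the next partition
                break
        end_idx = idx if block > 0 else idx - 1
        end_seqlen = block * block_size_n if block > 0 else seqlens_k[idx - 1]
        metadata.append([begin_idx, begin_seqlen, end_idx, end_seqlen, begin_split, 0, 0, 0])
    return metadata, num_splits

meta, splits = mla_metadata([300, 200, 1000], block_size_n=64, fixed_overhead_num_blocks=5, num_sm_parts=4)
print(splits)   # [0, 1, 2, 5]: batches 0 and 1 are unsplit, batch 2 is split into 3 parts
```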

examples/68_hopper_flash_mla/csrc/flash_mla.h (new file)
@@ -0,0 +1,63 @@
#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_mla_params {
    using index_t = int64_t;

    int b, seqlen_q, d, d_v;
    int h, h_h_k_ratio, ngroups;
    bool is_causal;
    float scale_softmax, scale_softmax_log2;
    int *__restrict__ cu_seqlens_k;

    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;
    void *__restrict__ o_ptr;
    void *__restrict__ softmax_lse_ptr;

    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t o_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t o_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;
    index_t o_head_stride;

    int *__restrict__ block_table;
    index_t block_table_batch_stride;
    int page_block_size;

    int *__restrict__ tile_scheduler_metadata_ptr;
    int num_sm_parts;
    int *__restrict__ num_splits_ptr;

    void *__restrict__ softmax_lseaccum_ptr;
    void *__restrict__ oaccum_ptr;
};

static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);

struct Mla_metadata_params {
    int *__restrict__ seqlens_k_ptr;
    int *__restrict__ tile_scheduler_metadata_ptr;
    int *__restrict__ num_splits_ptr;
    int batch_size;
    int block_size_n;
    int fixed_overhead_num_blocks;
    int num_sm_parts;
};

void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);
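Each scheduler row is 8 ints, of which only the first five are used (the rest are padding). A tiny Python helper to view one row, following the comment above (the helper is illustrative):

```python
def decode_metadata_row(row):
    begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx = row[:5]
    return dict(begin_idx=begin_idx, begin_seqlen=begin_seqlen,
                end_idx=end_idx, end_seqlen=end_seqlen,
                begin_n_split_idx=begin_n_split_idx)

print(decode_metadata_row([2, 64, 2, 704, 1, 0, 0, 0]))
```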

examples/68_hopper_flash_mla/csrc/named_barrier.h (new file)
@@ -0,0 +1,15 @@
#pragma once

#include "cutlass/barrier.h"

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts

enum class NamedBarriers {
    SReady = 1,
    SoftmaxReady = 2,
};

}  // flash

examples/68_hopper_flash_mla/csrc/softmax.h (new file)
@@ -0,0 +1,197 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h

#pragma once

#include <cmath>

#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
    for (int i = 0; i < size(dst); i++) {
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
    SumOp<float> sum_op;
    thread_reduce_<zero_init>(tensor, sum, sum_op);
}

// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            // The following macro will disable the use of fma.
            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
            // This macro is set in PyTorch and not FlashAttention
#ifdef UNFUSE_FMA
            tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
        }
    }
    return tensor;
}

// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
#pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
        SumOp<float> sum_op;
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
    }
}

template<typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
    for (int mi = 0; mi < size(scale_o); ++mi) {
#pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kNRows>
struct Softmax {

    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    TensorT row_max, row_sum;

    __forceinline__ __device__ Softmax() {};

    template<bool Is_first, bool Check_inf=false, typename Tensor0>
    __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        TensorT scale_o;
        clear(scale_o);
        if (Is_first) {
            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
#pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf
                    ? row_max(mi)
                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                scale_o(mi) = scores_scale;
                row_sum(mi) *= scores_scale;
            }
            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
        }
        return scale_o;
    };

    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
        }
        return lse;
    };
};

}  // namespace flash
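`Softmax<kNRows>` implements the online softmax: each new score block updates a running row max and row sum, and the partial output is rescaled by `exp(old_max - new_max)` whenever the max grows; `scale_apply_exp2` performs the same computation with `x * log2(e)` folded into the scale so `exp2f` and FFMA can be used. A NumPy sketch, assuming a single row processed in blocks, that checks the result against a plain softmax:

```python
import numpy as np

def online_softmax_accumulate(score_blocks, value_blocks):
    row_max, row_sum = -np.inf, 0.0
    acc = np.zeros(value_blocks[0].shape[1], dtype=np.float64)
    for s, v in zip(score_blocks, value_blocks):
        new_max = max(row_max, s.max())
        scale = np.exp(row_max - new_max) if np.isfinite(row_max) else 0.0
        p = np.exp(s - new_max)
        acc = acc * scale + p @ v          # rescale_o followed by P @ V
        row_sum = row_sum * scale + p.sum()
        row_max = new_max
    return acc / row_sum                   # normalize_softmax_lse

rng = np.random.default_rng(0)
s = rng.standard_normal(8); v = rng.standard_normal((8, 4))
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ v
out = online_softmax_accumulate([s[:4], s[4:]], [v[:4], v[4:]])
print(np.allclose(out, ref))               # True
```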

examples/68_hopper_flash_mla/csrc/static_switch.h (new file)
@@ -0,0 +1,65 @@
#pragma once

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())


#define FLASH_ASSERT(cond)                                                                    \
    do {                                                                                      \
        if (not (cond)) {                                                                     \
            fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond);     \
            exit(1);                                                                          \
        }                                                                                     \
    } while(0)


#define FLASH_DEVICE_ASSERT(cond)                                                             \
    do {                                                                                      \
        if (not (cond)) {                                                                     \
            printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond);              \
            asm("trap;");                                                                     \
        }                                                                                     \
    } while(0)


#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()


#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
    [&] {                                            \
        if (NUM_SPLITS <= 32) {                      \
            constexpr static int NAME = 32;          \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 64) {               \
            constexpr static int NAME = 64;          \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 96) {               \
            constexpr static int NAME = 96;          \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 128) {              \
            constexpr static int NAME = 128;         \
            return __VA_ARGS__();                    \
        } else if (NUM_SPLITS <= 160) {              \
            constexpr static int NAME = 160;         \
            return __VA_ARGS__();                    \
        } else {                                     \
            FLASH_ASSERT(false);                     \
        }                                            \
    }()
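`MLA_NUM_SPLITS_SWITCH` dispatches to the combine kernel compiled for the smallest `kMaxSplits` bucket that covers the runtime split count. A one-function Python equivalent of that bucketing (illustrative only):

```python
def mla_max_splits_bucket(num_splits):
    for bucket in (32, 64, 96, 128, 160):
        if num_splits <= bucket:
            return bucket
    raise AssertionError("num_splits > 160 is not supported")

print([mla_max_splits_bucket(n) for n in (1, 33, 132, 160)])   # [32, 64, 160, 160]
```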

examples/68_hopper_flash_mla/csrc/utils.h (new file)
@@ -0,0 +1,238 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_bf16.h>

#include <cute/tensor.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        if constexpr (!Transposed) {
            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
        } else {
            return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
        }

    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        if constexpr (!Transposed) {
            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
        } else {
            return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});  // ((2, N / 16))
            return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        } else {
            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});  // (((2, 2), N / 32))
            // This combines the first two modes (<0, 0> and <0, 1>) into one mode.
            // Will require register shuffling later to be correct.
            return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
                               get<1>(acc_layout),
                               coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));  // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
            // This combination is right but doesn't work with register shuffling.
            // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
            //                    get<1>(acc_layout),
            //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        }
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
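`Allreduce<THREADS>` is a butterfly reduction over `__shfl_xor_sync` lanes, halving the exchange offset each step (THREADS/2, THREADS/4, ..., 1). A pure-Python model of the same exchange pattern (illustrative only):

```python
def butterfly_allreduce(values, op, threads):
    # values: one entry per lane; afterwards every lane holds the reduction of its group of `threads`.
    assert threads in (2, 4, 8, 16, 32)
    vals = list(values)
    offset = threads // 2
    while offset >= 1:
        vals = [op(vals[i], vals[i ^ offset]) for i in range(len(vals))]
        offset //= 2
    return vals

print(butterfly_allreduce([3, 1, 7, 5], max, threads=4))   # [7, 7, 7, 7]
```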

examples/68_hopper_flash_mla/flash_mla/__init__.py (new file)
@@ -0,0 +1,6 @@
__version__ = "1.0.0"

from flash_mla.flash_mla_interface import (
    get_mla_metadata,
    flash_mla_with_kvcache,
)

examples/68_hopper_flash_mla/flash_mla/flash_mla_interface.py (new file)
@@ -0,0 +1,67 @@
from typing import Optional, Tuple

import torch

import flash_mla_cuda


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: The number of k heads.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse
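A minimal end-to-end usage of the two entry points above, following the docstrings and the test script further below. The sizes are illustrative (MLA-style single latent K/V head, head_dim 576, head_dim_v 512, 64-token pages) and assume an SM90 GPU with the extension built:

```python
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

b, s_q, h_q, h_kv, d, dv, block_size = 2, 1, 128, 1, 576, 512, 64
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")

cache_seqlens = torch.tensor([511, 1024], dtype=torch.int32)
max_blocks = (int(cache_seqlens.max()) + block_size - 1) // block_size
q = torch.randn(b, s_q, h_q, d)
k_cache = torch.randn(b * max_blocks, block_size, h_kv, d)
block_table = torch.arange(b * max_blocks, dtype=torch.int32).view(b, max_blocks)

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
out, lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)
print(out.shape, lse.shape)   # (2, 1, 128, 512) (2, 128, 1)
```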

examples/68_hopper_flash_mla/setup.py (new file)
@@ -0,0 +1,87 @@
import os
from pathlib import Path
from datetime import datetime

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
    IS_WINDOWS,
)

DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"

def append_nvcc_threads(nvcc_extra_args):
    nvcc_threads = os.getenv("NVCC_THREADS") or "32"
    return nvcc_extra_args + ["--threads", nvcc_threads]

def get_sources():
    sources = [
        "csrc/flash_api.cpp",
        "csrc/flash_fwd_mla_bf16_sm90.cu",
        "csrc/flash_fwd_mla_metadata.cu",
    ]

    if not DISABLE_FP16:
        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")

    return sources

def get_features_args():
    features_args = []
    if DISABLE_FP16:
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args

cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90a,code=sm_90a")

this_dir = os.path.dirname(os.path.abspath(__file__))

if IS_WINDOWS:
    cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
else:
    cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]

ext_modules = []
ext_modules.append(
    CUDAExtension(
        name="flash_mla_cuda",
        sources=get_sources(),
        extra_compile_args={
            "cxx": cxx_args + get_features_args(),
            "nvcc": append_nvcc_threads(
                [
                    "-O3",
                    "-std=c++17",
                    "-DNDEBUG",
                    "-D_USE_MATH_DEFINES",
                    "-Wno-deprecated-declarations",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    "--ptxas-options=-v,--register-usage-level=10"
                ]
                + cc_flag
            ) + get_features_args(),
        },
        include_dirs=[
            Path(this_dir) / "csrc",
            Path(this_dir) / ".." / ".." / "include",
        ],
    )
)

setup(
    name="flash_mla",
    version="1.0.0",
    packages=find_packages(include=['flash_mla']),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)

examples/68_hopper_flash_mla/tests/test_flash_mla.py (new file)
@@ -0,0 +1,153 @@
import argparse
import math
import random

import torch
import triton

from flash_mla import flash_mla_with_kvcache, get_mla_metadata


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
    x, y = x.double(), y.double()
    RMSE = ((x - y) * (x - y)).mean().sqrt().item()
    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
    amax_diff = (x - y).abs().max().item()
    # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
    assert cos_diff < 1e-5


@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
    print(
        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
    )

    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(
        b * max_seqlen_pad // block_size, dtype=torch.int32
    ).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
            float("nan")
        )
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )

    def flash_mla():
        return flash_mla_with_kvcache(
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            dv,
            tile_scheduler_metadata,
            num_splits,
            causal=causal,
        )

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q=h_q,
                h_kv=h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    out_flash, lse_flash = flash_mla()
    out_torch, lse_torch = ref_mla()
    cal_diff(out_flash, out_torch, "out")
    cal_diff(lse_flash, lse_torch, "lse")

    t = triton.testing.do_bench(flash_mla)
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
        torch.finfo(q.dtype).bits // 8
    )
    print(
        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
|
||||
)
|
||||
|
||||
|
||||
def main(torch_dtype):
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(torch_dtype)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
h_kv = 1
|
||||
d, dv = 576, 512
|
||||
causal = True
|
||||
|
||||
for b in [128]:
|
||||
for s in [4096, 8192]:
|
||||
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
|
||||
for s_q in [1, 2]: # MTP = 1, 2
|
||||
for varlen in [False, True]:
|
||||
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["bf16", "fp16"],
|
||||
default="bf16",
|
||||
help="Data type to use for testing (bf16 or fp16)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
torch_dtype = torch.bfloat16
|
||||
if args.dtype == "fp16":
|
||||
torch_dtype = torch.float16
|
||||
|
||||
main(torch_dtype)
|
||||
@ -1,773 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Grouped scale Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
This example demonstrates a grouped scaled FP8 Grouped GEMM using the new CUTLASS 3.0.
|
||||
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
|
||||
descriptors to move between groups/problem_count (represented by groups).
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
Examples:
|
||||
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling \
|
||||
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
|
||||
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <cfloat>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 128;
|
||||
constexpr int ScaleGranularityK = 128;
|
||||
|
||||
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutD *, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopWithGroupWiseScaling,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
/// Initialization
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_blockscale_A;
|
||||
std::vector<int64_t> offset_blockscale_B;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFB> layout_SFB_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<ElementD> block_D;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
|
||||
|
||||
cutlass::DeviceAllocation<const ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_D;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
|
||||
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023,
|
||||
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
|
||||
|
||||
double _scope_max, _scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
if (bits_input == 1) {
|
||||
_scope_max = 2;
|
||||
_scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
_scope_max = 2;
|
||||
_scope_min = -2;
|
||||
} else if (bits_input == 16) {
|
||||
_scope_max = 5;
|
||||
_scope_min = -5;
|
||||
} else {
|
||||
_scope_max = 8;
|
||||
_scope_min = -8;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
|
||||
_scope_max = scope_max;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
|
||||
_scope_min = scope_min;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
template <typename OptionType>
|
||||
void allocate(const OptionType &options) {
|
||||
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_blockscale_A = 0;
|
||||
int64_t total_elements_blockscale_B = 0;
|
||||
|
||||
offset_A.clear();
|
||||
offset_B.clear();
|
||||
offset_C.clear();
|
||||
offset_D.clear();
|
||||
offset_blockscale_A.clear();
|
||||
offset_blockscale_B.clear();
|
||||
stride_A_host.clear();
|
||||
stride_B_host.clear();
|
||||
stride_C_host.clear();
|
||||
stride_D_host.clear();
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
|
||||
auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_blockscale_A.push_back(total_elements_blockscale_A);
|
||||
offset_blockscale_B.push_back(total_elements_blockscale_B);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA));
|
||||
int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB));
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_blockscale_A += elements_blockscale_A;
|
||||
total_elements_blockscale_B += elements_blockscale_B;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
layout_SFA_host.push_back(group_layout_SFA);
|
||||
layout_SFB_host.push_back(group_layout_SFB);
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
blockscale_block_A.reset(total_elements_blockscale_A);
|
||||
blockscale_block_B.reset(total_elements_blockscale_B);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
template <typename OptionType>
|
||||
void initialize(const OptionType &options) {
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementD *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
|
||||
|
||||
alpha_host.clear();
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_blockscale_A.reset(options.groups);
|
||||
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
|
||||
|
||||
ptr_blockscale_B.reset(options.groups);
|
||||
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
layout_SFA.reset(options.groups);
|
||||
layout_SFA.copy_from_host(layout_SFA_host.data());
|
||||
|
||||
layout_SFB.reset(options.groups);
|
||||
layout_SFB.copy_from_host(layout_SFB_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2022);
|
||||
initialize_block(block_B, seed + 2023);
|
||||
initialize_block(block_C, seed + 2024);
|
||||
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
|
||||
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template<typename GemmArguments, typename OptionType>
|
||||
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
int device_id = 0;
|
||||
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename Gemm::GemmKernel>(device_id);
|
||||
|
||||
GemmArguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_blockscale_A.get(), layout_SFA.get(),
|
||||
ptr_blockscale_B.get(), layout_SFB.get()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
ptr_C.get(), stride_C.get(),
|
||||
ptr_D.get(), stride_D.get()
|
||||
},
|
||||
kernel_hw_info
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename OptionType>
|
||||
bool verify(const OptionType &options) {
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
std::vector<ElementB> block_B_host(block_B.size());
|
||||
std::vector<ElementC> block_C_host(block_C.size());
|
||||
std::vector<ElementD> block_D_host_kernel(block_D.size());
|
||||
std::vector<ElementD> block_D_host_ref(block_D.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
|
||||
|
||||
block_A.copy_to_host(block_A_host.data());
|
||||
block_B.copy_to_host(block_B_host.data());
|
||||
block_C.copy_to_host(block_C_host.data());
|
||||
block_D.copy_to_host(block_D_host_kernel.data());
|
||||
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, k, 1),
|
||||
stride_A_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(n, k, 1),
|
||||
stride_B_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_C_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_D_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
|
||||
auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
|
||||
layout_SFA_host.at(group_idx));
|
||||
auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
|
||||
layout_SFB_host.at(group_idx));
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // Aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = alpha_host.at(group_idx);
|
||||
epilogue_params.beta = beta_host.at(group_idx);
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
auto this_group_passed = std::equal(
|
||||
// std::execution::par_unseq,
|
||||
block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
|
||||
block_D_host_kernel.data() + offset_D.at(group_idx)
|
||||
);
|
||||
|
||||
passed &= this_group_passed;
|
||||
|
||||
#if 0
|
||||
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename OptionType>
|
||||
int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions, ProblemShape> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running tests with host problem shapes:" << std::endl;
|
||||
run(options, true);
|
||||
std::cout << "Running tests without host problem shapes:" << std::endl;
|
||||
run(options, false);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,781 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Grouped scale Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
This example demonstrates a grouped scaled FP8 Grouped GEMM using the new CUTLASS 3.0.
|
||||
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
|
||||
descriptors to move between groups/problem_count (represented by groups).
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
5. This example is tuned specifically for the sparse groups case, where the number of active groups (groups
|
||||
with non-zero problem count) is much smaller than the total number of groups.
|
||||
Examples:
|
||||
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups \
|
||||
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
|
||||
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <cfloat>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
static constexpr int ScaleGranularityM = 1;
|
||||
static constexpr int ScaleGranularityN = 128;
|
||||
static constexpr int ScaleGranularityK = 128;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutD *, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopWithGroupWiseScaling,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
/// Initialization
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_blockscale_A;
|
||||
std::vector<int64_t> offset_blockscale_B;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFB> layout_SFB_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<ElementD> block_D;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
|
||||
|
||||
cutlass::DeviceAllocation<const ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_D;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
|
||||
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
double gbps;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
double gbps = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), gbps(gbps), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023,
|
||||
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
|
||||
|
||||
double _scope_max, _scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
if (bits_input == 1) {
|
||||
_scope_max = 2;
|
||||
_scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
_scope_max = 2;
|
||||
_scope_min = -2;
|
||||
} else if (bits_input == 16) {
|
||||
_scope_max = 5;
|
||||
_scope_min = -5;
|
||||
} else {
|
||||
_scope_max = 8;
|
||||
_scope_min = -8;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
|
||||
_scope_max = scope_max;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
|
||||
_scope_min = scope_min;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
template <typename OptionType>
|
||||
void allocate(const OptionType &options) {
|
||||
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_blockscale_A = 0;
|
||||
int64_t total_elements_blockscale_B = 0;
|
||||
|
||||
offset_A.clear();
|
||||
offset_B.clear();
|
||||
offset_C.clear();
|
||||
offset_D.clear();
|
||||
offset_blockscale_A.clear();
|
||||
offset_blockscale_B.clear();
|
||||
stride_A_host.clear();
|
||||
stride_B_host.clear();
|
||||
stride_C_host.clear();
|
||||
stride_D_host.clear();
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_after_alignment_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
|
||||
auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_blockscale_A.push_back(total_elements_blockscale_A);
|
||||
offset_blockscale_B.push_back(total_elements_blockscale_B);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA));
|
||||
int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB));
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_blockscale_A += elements_blockscale_A;
|
||||
total_elements_blockscale_B += elements_blockscale_B;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
layout_SFA_host.push_back(group_layout_SFA);
|
||||
layout_SFB_host.push_back(group_layout_SFB);
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
blockscale_block_A.reset(total_elements_blockscale_A);
|
||||
blockscale_block_B.reset(total_elements_blockscale_B);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
template <typename OptionType>
|
||||
void initialize(const OptionType &options) {
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_after_alignment_host.data());
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementD *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
|
||||
|
||||
alpha_host.clear();
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
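// FLT_MAX acts as a sentinel for "alpha/beta not fixed"; in that case pick a random
// per-group alpha in [1, 5] and beta in [0, 4].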
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_blockscale_A.reset(options.groups);
|
||||
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
|
||||
|
||||
ptr_blockscale_B.reset(options.groups);
|
||||
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
layout_SFA.reset(options.groups);
|
||||
layout_SFA.copy_from_host(layout_SFA_host.data());
|
||||
|
||||
layout_SFB.reset(options.groups);
|
||||
layout_SFB.copy_from_host(layout_SFB_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2022);
|
||||
initialize_block(block_B, seed + 2023);
|
||||
initialize_block(block_C, seed + 2024);
|
||||
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
|
||||
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template<typename GemmArguments, typename OptionType>
|
||||
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
int device_id = 0;
|
||||
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename Gemm::GemmKernel>(device_id);
|
||||
|
||||
GemmArguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_after_alignment_host.data() : (decltype(options.problem_sizes_after_alignment_host.data())) nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_blockscale_A.get(), layout_SFA.get(),
|
||||
ptr_blockscale_B.get(), layout_SFB.get()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
ptr_C.get(), stride_C.get(),
|
||||
ptr_D.get(), stride_D.get()
|
||||
},
|
||||
kernel_hw_info
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha and beta are provided via the command line, the same scalar alpha/beta applies to all groups.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// Otherwise, use per-group alpha/beta pointer arrays, so alpha/beta can differ between groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8, rounded down to the nearest power of two (i.e., 1, 2, 4, or 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename OptionType>
|
||||
bool verify(const OptionType &options) {
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
std::vector<ElementB> block_B_host(block_B.size());
|
||||
std::vector<ElementC> block_C_host(block_C.size());
|
||||
std::vector<ElementD> block_D_host_kernel(block_D.size());
|
||||
std::vector<ElementD> block_D_host_ref(block_D.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
|
||||
|
||||
block_A.copy_to_host(block_A_host.data());
|
||||
block_B.copy_to_host(block_B_host.data());
|
||||
block_C.copy_to_host(block_C_host.data());
|
||||
block_D.copy_to_host(block_D_host_kernel.data());
|
||||
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensor shapes are derived from `ScaleGranularityM`, the CTA tile shape (TileShape), and the GEMM problem shape
|
||||
auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create host-side tensor views for the reference GEMM
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, k, 1),
|
||||
stride_A_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(n, k, 1),
|
||||
stride_B_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_C_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_D_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
|
||||
auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
|
||||
layout_SFA_host.at(group_idx));
|
||||
auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
|
||||
layout_SFB_host.at(group_idx));
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = alpha_host.at(group_idx);
|
||||
epilogue_params.beta = beta_host.at(group_idx);
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
auto this_group_passed = std::equal(
|
||||
// std::execution::par_unseq,
|
||||
block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
|
||||
block_D_host_kernel.data() + offset_D.at(group_idx)
|
||||
);
|
||||
|
||||
passed &= this_group_passed;
|
||||
|
||||
#if 0
|
||||
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename OptionType>
|
||||
int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
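// Re-initialize the kernel before each run so any workspace-resident state (e.g., tile-scheduler counters) starts fresh.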
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
result.gbps = options.template gbps<ElementA,
|
||||
ElementB,
|
||||
ElementC,
|
||||
ElementD,
|
||||
ElementBlockScale,
|
||||
TileShape,
|
||||
ScaleMsPerTile,
|
||||
ScaleNsPerTile>(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with the CUDA 12.3 Toolkit (or newer) to run this example
// and must be run on a device with compute capability 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Return zero so this test passes on older toolkits; in that case the example is a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions, ProblemShape> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
run(options, true);
|
||||
|
||||
std::cout << "Running tests without host problem shapes:" << std::endl;
|
||||
run(options, false);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,84 +0,0 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
|
||||
# Only the correctness check will be run by these commands.
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=512 --groups=50 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
|
||||
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
)
|
||||
|
||||
# MSVC will fail to compile this example with the following error:
|
||||
# fatal error C1083: Cannot open source file: <Some Mojibake>: No such file or directory [...\examples\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.vcxproj]
|
||||
# This is a known issue and we are working on a fix.
|
||||
if (NOT MSVC)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
)
|
||||
|
||||
endif()
|
||||
@ -1,271 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
template<typename _RasterOrderOptions, typename _ProblemShape>
|
||||
struct Options {
|
||||
|
||||
using RasterOrderOptions = _RasterOrderOptions;
|
||||
using ProblemShape = _ProblemShape;
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, groups = 10;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_after_alignment_host;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
|
||||
int const k_alignment = 128;
|
||||
int const m_alignment = 128;
|
||||
int const n_alignment = 128;
|
||||
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char = 'H'; // default to heuristic rasterization if --raster is not given
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_after_alignment_host.clear();
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_after_alignment_host.reserve(groups);
|
||||
problem_sizes_host.reserve(groups);
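// For any extent not fixed on the command line (left at -1), pick a random size that is a
// multiple of the corresponding alignment.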
|
||||
for (int i = groups; i > 0; i--) {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
|
||||
}
|
||||
problem_sizes_after_alignment_host.push_back({m, n, k});
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
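/// Each line of the benchmark file contains an index followed by an MxNxK extent
/// (for example, `0 1024x512x4096`).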
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent_after_alignment, extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
extent.at(i) = x;
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent_after_alignment.at(i) = x;
|
||||
}
|
||||
|
||||
problem_sizes_after_alignment_host.push_back({extent_after_alignment.m(), extent_after_alignment.n(), extent_after_alignment.k()});
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_after_alignment_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Calculate memory bandwidth statistics
|
||||
template <class ElementA,
|
||||
class ElementB,
|
||||
class ElementC,
|
||||
class ElementD,
|
||||
class ElementBlockScale,
|
||||
class TileShape,
|
||||
int ScaleMsPerTile,
|
||||
int ScaleNsPerTile>
|
||||
auto gbps(double runtime_s) const {
|
||||
double total_read_bytes = 0;
|
||||
double total_write_bytes = 0;
|
||||
|
||||
// Calculate bytes read and written for each problem
|
||||
for (int i = 0; i < groups; ++i) {
|
||||
auto problem = problem_sizes_host.at(i);
|
||||
auto M = cute::get<0>(problem);
|
||||
auto N = cute::get<1>(problem);
|
||||
auto K = cute::get<2>(problem);
|
||||
|
||||
if (M > 0) { // Only count active problems
|
||||
// Matrix A: M*K elements read
|
||||
total_read_bytes += M * K * sizeof(ElementA);
|
||||
|
||||
// Matrix B: K*N elements read
|
||||
total_read_bytes += K * N * sizeof(ElementB);
|
||||
|
||||
// Matrix C: M*N elements read (for beta operation)
|
||||
total_read_bytes += M * N * sizeof(ElementC);
|
||||
|
||||
// Block scales for A and B
|
||||
auto blockscale_shape = cute::shape(cute::get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{})));
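// blockscale_shape counts CTA tiles along (M, N, K); A contributes ScaleMsPerTile scales per
// tile in M and B contributes ScaleNsPerTile per tile in N, with one set per K tile.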
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
auto groupscale_m = blockscale_m * ScaleMsPerTile;
|
||||
auto groupscale_n = blockscale_n * ScaleNsPerTile;
|
||||
|
||||
total_read_bytes += groupscale_m * blockscale_k * sizeof(ElementBlockScale); // A scales
|
||||
total_read_bytes += groupscale_n * blockscale_k * sizeof(ElementBlockScale); // B scales
|
||||
|
||||
// Matrix D: M*N elements written
|
||||
total_write_bytes += M * N * sizeof(ElementD);
|
||||
}
|
||||
}
|
||||
|
||||
return (total_read_bytes + total_write_bytes) / 1.0e9 / runtime_s;
|
||||
}
|
||||
|
||||
double bandwidth_util(double eff_bandwidth) const {
|
||||
int memoryClockRate;
|
||||
int memoryBusWidth;
|
||||
cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, 0);
|
||||
cudaDeviceGetAttribute(&memoryBusWidth, cudaDevAttrGlobalMemoryBusWidth , 0);
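// Theoretical peak bandwidth in GB/s: 2 (DDR) x memory clock (kHz) x bus width (bytes) / 1e6.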
|
||||
double bw = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6;
|
||||
return eff_bandwidth / bw * 100.0;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n"
|
||||
<< " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size.\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = 0ull;
|
||||
|
||||
for (auto const [m, n, k] : problem_sizes_host) {
|
||||
fmas += static_cast<uint64_t>(m) *
|
||||
static_cast<uint64_t>(n) *
|
||||
static_cast<uint64_t>(k);
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
@ -374,7 +374,7 @@ void allocate(Options const& options) {
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
const int scale_k = 1;
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
@ -510,7 +510,7 @@ void initialize(Options &options) {
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
@ -519,13 +519,13 @@ void initialize(Options &options) {
|
||||
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
const int scale_k = 1;
|
||||
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
|
||||
auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{});
|
||||
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
|
||||
auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.c, stream);
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream);
|
||||
}
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
@ -619,7 +619,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
@ -676,7 +676,6 @@ bool verify(Options const& options) {
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
@ -713,7 +712,7 @@ bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
std::cout << "Group: " << i << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
|
||||
@ -341,7 +341,7 @@ void allocate(Options const& options) {
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
const int scale_k = 1;
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
@ -479,7 +479,7 @@ void initialize(Options& options) {
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2022);
|
||||
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
@ -565,7 +565,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.c},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
@ -617,7 +617,6 @@ bool verify(Options const& options) {
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
@ -631,11 +630,11 @@ bool verify(Options const& options) {
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
const int scale_k = 1;
|
||||
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
|
||||
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
@ -660,7 +659,7 @@ bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
std::cout << "Group: " << i << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
|
||||
@ -282,7 +282,7 @@ void allocate(Options const& options) {
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
const int scale_k = 1;
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
@ -418,7 +418,7 @@ void initialize(Options &options) {
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
@ -485,7 +485,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
@ -542,7 +542,6 @@ bool verify(Options const& options) {
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
@ -556,11 +555,11 @@ bool verify(Options const& options) {
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
const int scale_k = 1;
|
||||
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
|
||||
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
@ -585,7 +584,7 @@ bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
std::cout << "Group: " << i << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
|
||||
@ -50,7 +50,6 @@ set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10)
|
||||
set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion
|
||||
|
||||
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling
|
||||
set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0) # Group-wise scaling
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_mixed_dtype_grouped_gemm
|
||||
@ -70,7 +69,6 @@ cutlass_example_add_executable(
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
@ -91,7 +89,6 @@ cutlass_example_add_executable(
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
@ -112,5 +109,4 @@ cutlass_example_add_executable(
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
@ -7,40 +7,8 @@ This example shows how to perform Grouped GEMMs on Hopper when A and B have diff
|
||||
- in the arguments, pass the group size, array of the problem sizes, and the array of strides for matrix A and B.
|
||||
- if scales and zero-points are included, also pass the array of their strides in the arguments.
|
||||
|
||||
Note that in Example 55, the argument `--g` is used to determine the group size of scaling. To avoid confusion with the `--groups` argument in this example, which defines the number of GEMMs, `--c` is used here to represent the group size for scaling.
|
||||
Note that in Example 55, the argument `--g` is used to determine the block scale size. It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs.
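A minimal sketch of how these per-group arrays are passed is shown below. It is modeled on the grouped argument construction in the examples above; `num_groups`, `problem_sizes`, `ptr_A`, `stride_A`, `fusion_args`, and `hw_info` are placeholders for the arrays and structures an application would build, not names defined by this README.

```c++
// Illustrative fragment only: the exact mainloop argument order depends on the kernel
// (the mixed-dtype example passes B before A because it swaps and transposes the operands).
typename Gemm::Arguments arguments{
  cutlass::gemm::GemmUniversalMode::kGrouped,
  // group count, device array of per-group (M, N, K) shapes, optional host copy of the shapes
  {num_groups, problem_sizes.get(), /*host_problem_shapes=*/nullptr},
  // per-group device pointer arrays and per-group strides for the operands
  {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
  // epilogue: fusion arguments plus per-group C/D pointers and strides
  {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
  hw_info
};
```

If scales and zero-points are used, their per-group pointer and stride arrays are appended to the mainloop arguments in the same way.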
|
||||
|
||||
## Upcoming features
|
||||
|
||||
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling, and group-wise scaling for identical problem shapes across all groups. Please contact us if zero-points or block-wise scaling are needed.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed.
|
||||
|
||||
@ -58,7 +58,6 @@ public:
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
MixedDtypeOptions::parse(argc, args);
|
||||
|
||||
@ -72,7 +71,6 @@ public:
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --c=<int> Sets the chunk size for scaling the quantized weights\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems\n"
|
||||
<< " --mode=<int> The mode to run the gemm\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
@ -185,6 +183,11 @@ void grouped_mixed_dtype_profiling(
|
||||
|
||||
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta\n";
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n';
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << '\n'
|
||||
<< " Avg runtime : " << result.avg_runtime_ms << " ms\n"
|
||||
<< " GFLOPS : " << result.gflops << '\n';
|
||||
|
||||
@ -194,14 +194,12 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(8192), n(8192), k(8192),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -219,7 +217,6 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -234,7 +231,6 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
@ -335,8 +331,6 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -231,7 +231,6 @@ struct Options {
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
int swizzle = 0;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@ -257,7 +256,6 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -273,7 +271,6 @@ struct Options {
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --scale_a=<f32> Scaling factor for A\n"
|
||||
<< " --scale_b=<f32> Scaling factor for B\n"
|
||||
<< " --scale_c=<f32> Scaling factor for C\n"
|
||||
@ -479,8 +476,6 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -28,29 +28,14 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(TEST_SWIZZLE_1 --swizzle=1)
|
||||
set(TEST_SWIZZLE_2 --swizzle=2)
|
||||
set(TEST_SWIZZLE_5 --swizzle=5)
|
||||
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp16_gemm
|
||||
70_blackwell_fp16_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_SWIZZLE_1
|
||||
TEST_SWIZZLE_2
|
||||
TEST_SWIZZLE_5
|
||||
TEST_SWIZZLE_5_UNEVEN
|
||||
)
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp8_gemm
|
||||
70_blackwell_fp8_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_SWIZZLE_1
|
||||
TEST_SWIZZLE_2
|
||||
TEST_SWIZZLE_5
|
||||
TEST_SWIZZLE_5_UNEVEN
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -74,14 +74,12 @@ struct Options {
|
||||
|
||||
int m, n, k, l;
|
||||
float alpha, beta;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
swizzle(0)
|
||||
alpha(1.f), beta(0.f)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -99,7 +97,6 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -115,8 +112,7 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n\n";
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -356,8 +352,6 @@ struct ExampleRunner {
|
||||
hw_info
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
// See example 48 for details on custom EVT construction
|
||||
if constexpr (UseCustomEVT) {
|
||||
arguments.epilogue.thread =
|
||||
|
||||
@ -211,14 +211,12 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -236,7 +234,6 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -250,8 +247,7 @@ struct Options {
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
@ -337,7 +333,7 @@ bool initialize_block(
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
@ -348,8 +344,8 @@ void initialize(const Options &options) {
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
@ -391,7 +387,6 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -177,7 +177,7 @@ using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
|
||||
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using SfdOutputCfg = cutlass::detail::Sm100BlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
|
||||
//
|
||||
@ -240,14 +240,12 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -265,7 +263,6 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -279,8 +276,7 @@ struct Options {
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
@ -366,9 +362,9 @@ bool initialize_block(
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
// For SFD tensor layout
|
||||
using Sm1xxBlockScaledOutputConfig= typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm100BlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
@ -379,8 +375,8 @@ void initialize(const Options &options) {
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
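// Unlike SFA/SFB (scale factors consumed by the mainloop), SFD is produced by the
// epilogue when the output tensor itself is block-scaled; its layout comes from the
// output config so the generated scale factors match the OutputSFVectorSize granularity.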
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
@ -436,7 +432,6 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
@ -480,12 +475,7 @@ bool verify(const Options &options) {
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
block_SFD.sync_host();
|
||||
bool passed_sfd = cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view());
|
||||
passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0);
|
||||
passed_sfd &= (cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0);
|
||||
|
||||
return passed && passed_sfd;
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
|
||||
@ -212,14 +212,12 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -237,7 +235,6 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -251,8 +248,7 @@ struct Options {
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
@ -338,7 +334,7 @@ bool initialize_block(
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
@ -349,8 +345,8 @@ void initialize(const Options &options) {
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
@ -392,7 +388,6 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -214,8 +214,7 @@ struct Options {
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
|
||||
int swizzle = 0;
|
||||
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(4096), n(4096), k(4096),
|
||||
@ -224,8 +223,7 @@ struct Options {
|
||||
preferred_cluster_m(4),
|
||||
preferred_cluster_n(4),
|
||||
fallback_cluster_m(2),
|
||||
fallback_cluster_n(1),
|
||||
swizzle(0)
|
||||
fallback_cluster_n(1)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -247,7 +245,6 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
|
||||
if (!validate_cluster_shape()){
|
||||
std::cout << "--Invalid cluster shapes" << std::endl;
|
||||
@ -268,7 +265,6 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
|
||||
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
|
||||
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
|
||||
@ -388,8 +384,7 @@ typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
|
||||
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -242,7 +242,6 @@ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTil
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool use_pdl = false;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
@ -265,9 +264,6 @@ struct Options {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("use_pdl")) {
|
||||
use_pdl = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
@ -391,8 +387,7 @@ struct Options {
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
|
||||
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
@ -716,7 +711,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
CUTLASS_CHECK(gemm.run());
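// When --use_pdl is given, the kernel is launched with Programmatic Dependent Launch,
// which lets a dependent kernel start its prologue before the preceding kernel has fully
// retired; the extra run() arguments (stream, cuda_adapter, launch_with_pdl) otherwise
// fall back to the plain launch path.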
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
@ -735,7 +730,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
|
||||
@ -219,14 +219,14 @@ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig<
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig<
|
||||
OutputSFVectorSize,
|
||||
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
|
||||
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
|
||||
>;
|
||||
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
|
||||
using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
|
||||
using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom;
|
||||
using LayoutSFD = typename Sm100BlockScaledOutputConfig::LayoutSF;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<StrideA> stride_A_host;
|
||||
@ -305,7 +305,6 @@ struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool verification = true;
|
||||
bool use_pdl = false;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
@ -329,12 +328,9 @@ struct Options {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("no_verif")) {
|
||||
if (cmd.check_cmd_line_flag("no-verif")) {
|
||||
verification = false;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("use_pdl")) {
|
||||
use_pdl = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
@ -461,8 +457,7 @@ struct Options {
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
|
||||
<< " --no_verif Do not run (host-side) verification kernels\n"
|
||||
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
|
||||
<< " --no-verif Do not run (host-side) verification kernels\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
@ -559,9 +554,9 @@ void allocate(const Options &options) {
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
stride_A_host.push_back(stride_A);
|
||||
stride_B_host.push_back(stride_B);
|
||||
@ -780,9 +775,9 @@ bool verify(const Options &options) {
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A);
|
||||
@ -850,7 +845,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
@ -875,7 +870,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
|
||||
@ -67,6 +67,9 @@
|
||||
--b=2048 --h=2048 --d=2048 --q=2048 --k=2048
|
||||
*/
|
||||
|
||||
#define DSHOW(x) print(#x ": "); print(x); print("\n");
|
||||
#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n");
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
@ -244,8 +247,8 @@ struct Options {
|
||||
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
||||
<< " with the last batch sized to make it fit\n"
|
||||
<< " implies at least residual masking for correctness\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
|
||||
@ -1,865 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
\brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3.

This example showcases the use of CUTLASS to build backward fused
multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting
the NVIDIA Blackwell architecture.

Background and motivation
-------------------------
CUTLASS is a highly flexible library that provides open-source building blocks
for tensor core programming for GEMM or GEMM-like problems. Fused multi-head
attention (FMHA) is a foundational kernel for large language models (LLMs) since it
makes long sequence lengths feasible from a memory-usage perspective. It also
improves computational efficiency since it transforms an outer-product-like and
a matrix-vector-like GEMM into a fused operation with much higher arithmetic
intensity. For more details, see Dao et al, 2022; Dao, 2023.
Implementing this kernel in CUTLASS enables easy customization and high
performance.

Introduction
------------
The example targets the NVIDIA Blackwell architecture and takes advantage of
5th gen tensor cores and the Tensor Memory Accelerator (TMA), just like
GEMMs do. It provides a backward pass (often abbreviated
bwd in the code).
The code is structured into three layers: the runner (and the reference kernels)
takes care of initialization, measurement, and testing; the device layer
orchestrates kernel calls and partitions workspace; and the kernel layer
(just like the CUTLASS kernel layer) implements the device-side computation.

Support
---------
We support fp16 and fp8 data types with a head dimension of 128.

Example usage:
$ ./examples/77_blackwell_fmha/77_blackwell_fmha_bwd_fp16 \
--b=2048 --h=2048 --d=2048 --q=2048 --k=2048
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "reference/fmha_fwd_reference.hpp"
|
||||
#include "reference/fmha_bwd_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "device/fmha_device_bwd.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
using namespace cutlass::fmha::collective;
|
||||
using namespace cutlass::fmha;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kZero, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 16;
|
||||
int h = 16;
|
||||
int h_k = 1;
|
||||
int q = 1024;
|
||||
int k = 1024;
|
||||
int d = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_k = InitStyle::kRandom;
|
||||
InitStyle init_style_v = InitStyle::kRandom;
|
||||
InitStyle init_style_do = InitStyle::kRandom;
|
||||
bool skip_reference = false;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "0") {
|
||||
dst = InitStyle::kZero;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
if (q == -1) q = k;
|
||||
if (k == -1) k = q;
|
||||
if (q == -1 && k == -1) q = k = defaults.q;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
if (mask == "causal") {
|
||||
causal = true;
|
||||
}
|
||||
else {
|
||||
causal = defaults.causal;
|
||||
}
|
||||
|
||||
skip_reference = cmd.check_cmd_line_flag("skip-reference");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_k);
|
||||
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_v);
|
||||
get_init_style_argument(cmd, "init-style", init_style_do, defaults.init_style_do);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k);
|
||||
get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v);
|
||||
get_init_style_argument(cmd, "init-style-do", init_style_v, init_style_do);
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_fmha_bwd\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " fused multi-head attention kernels for the backward pass targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|causal> Enables masking\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kZero: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 0);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (i % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class TileShape,
|
||||
class DispatchPolicy,
|
||||
class ActiveMask,
|
||||
class... KernelOptions
|
||||
>
|
||||
struct BwdRunner {
|
||||
|
||||
#ifdef FP8
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
#endif
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Q K D (H B)
|
||||
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
|
||||
|
||||
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
|
||||
|
||||
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
|
||||
using StrideQ = TensorStride;
|
||||
using StrideK = TensorStride;
|
||||
using StrideV = TensorStride;
|
||||
using StrideO = TensorStride;
|
||||
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
|
||||
|
||||
// Backwards specific
|
||||
using StrideDQ = TensorStride;
|
||||
using StrideDK = TensorStride;
|
||||
using StrideDV = TensorStride;
|
||||
using StrideDO = TensorStride;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideQ stride_Q;
|
||||
StrideK stride_K;
|
||||
StrideV stride_V;
|
||||
StrideO stride_O;
|
||||
StrideLSE stride_LSE;
|
||||
|
||||
StrideDQ stride_dQ;
|
||||
StrideDK stride_dK;
|
||||
StrideDV stride_dV;
|
||||
StrideDO stride_dO;
|
||||
|
||||
uint64_t seed = 0;
|
||||
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_K;
|
||||
DeviceAllocation<Element> block_V;
|
||||
DeviceAllocation<Element> block_O;
|
||||
DeviceAllocation<ElementAccumulator> block_LSE;
|
||||
|
||||
DeviceAllocation<Element> block_dQ;
|
||||
DeviceAllocation<Element> block_dK;
|
||||
DeviceAllocation<Element> block_dV;
|
||||
DeviceAllocation<Element> block_dO;
|
||||
|
||||
DeviceAllocation<Element> block_ref_dQ;
|
||||
DeviceAllocation<Element> block_ref_dK;
|
||||
DeviceAllocation<Element> block_ref_dV;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
bool verify(const ProblemShapeType& problem_shape) {
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [H, B] = HB;
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_dQ);
|
||||
|
||||
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_dK);
|
||||
|
||||
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_dV);
|
||||
|
||||
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_dO);
|
||||
|
||||
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-0 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
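// Looser tolerances for 1-byte (FP8 e4m3) elements: with only a 3-bit mantissa the
// reference and device results can legitimately diverge more than for half precision,
// so the max/mean absolute-difference thresholds are relaxed by roughly two orders of
// magnitude.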
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
reference_abs_diff(block_dQ, block_ref_dQ, max_diff, mean_diff);
|
||||
|
||||
bool passed_dQ = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_dQ) {
|
||||
std::cerr << "failed dQ: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_dK, block_ref_dK, max_diff, mean_diff);
|
||||
|
||||
bool passed_dK = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_dK) {
|
||||
std::cerr << "failed dK: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_dV, block_ref_dV, max_diff, mean_diff);
|
||||
|
||||
bool passed_dV = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_dV) {
|
||||
std::cerr << "failed dV: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
return passed_dQ && passed_dK && passed_dV;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
|
||||
auto shape_QO = select<0,2,3>(problem_shape);
|
||||
auto shape_KV = select<1,2,3>(problem_shape);
|
||||
auto shape_LSE = select<0,3>(problem_shape);
|
||||
|
||||
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
|
||||
stride_V = stride_K;
|
||||
stride_O = stride_Q;
|
||||
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
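// Packed "Seq D (H B)" strides: D is contiguous (stride _1), a step in Seq moves by D,
// and the (H, B) modes are laid out head-major within each batch. For example, element
// (q, d, (h, b)) of Q lives at offset q*D + d + h*D*Q + b*D*Q*H.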
|
||||
|
||||
stride_dQ = stride_Q;
|
||||
stride_dK = stride_K;
|
||||
stride_dV = stride_V;
|
||||
stride_dO = stride_O;
|
||||
|
||||
auto lsize = [](auto shape) {
|
||||
return size(make_shape(1ull, shape));
|
||||
};
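// lsize() folds a shape into a single element count, seeding the product with 1ull so
// the multiplication happens in 64 bits and large Q*D*H*B extents cannot overflow int.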
|
||||
|
||||
block_Q.reset(lsize(shape_QO));
|
||||
block_K.reset(lsize(shape_KV));
|
||||
block_V.reset(lsize(shape_KV));
|
||||
block_O.reset(lsize(shape_QO));
|
||||
block_LSE.reset(lsize(shape_LSE));
|
||||
|
||||
block_dQ.reset(lsize(shape_QO));
|
||||
block_dK.reset(lsize(shape_KV));
|
||||
block_dV.reset(lsize(shape_KV));
|
||||
block_dO.reset(lsize(shape_QO));
|
||||
|
||||
block_ref_dQ.reset(lsize(shape_QO));
|
||||
block_ref_dK.reset(lsize(shape_KV));
|
||||
block_ref_dV.reset(lsize(shape_KV));
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||
initialize_block(block_dO, seed + 2020, options.init_style_do);
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
if (! options.skip_reference) {
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
}
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
|
||||
|
||||
initialize(problem_shape, options);
|
||||
|
||||
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
block_Q.get(), stride_Q,
|
||||
block_K.get(), stride_K,
|
||||
block_V.get(), stride_V,
|
||||
block_O.get(), stride_O,
|
||||
block_LSE.get(), stride_LSE,
|
||||
block_dO.get(), stride_dO,
|
||||
block_dQ.get(), stride_dQ,
|
||||
block_dK.get(), stride_dK,
|
||||
block_dV.get(), stride_dV,
|
||||
softmax_scale,
|
||||
hw_info
|
||||
};
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(get<0>(problem_shape));
|
||||
flops *= static_cast<double>(get<1>(problem_shape));
|
||||
flops *= static_cast<double>(get<2>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,1>(problem_shape));
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
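// FLOP model: the backward pass performs roughly five GEMM-shaped contractions over
// (Q x K x D) per head (recompute S = Q*K^T, then dV, dP, dQ, dK), i.e. about
// 5 * 2 = 10 FLOPs per entry, hence the factor 10.0 above; a causal mask touches only
// about half of the Q*K tiles, hence the 0.5. TFLOP divided by seconds gives TFLOPS/s.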
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms << "ms, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct KernelCoop {};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
using HeadDim = _64;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
using HeadDim = _128;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability 100a) and CUDA 12.8 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
auto with_causal = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask{});
|
||||
}
|
||||
else {
|
||||
fn(NoMask{});
|
||||
}
|
||||
};
|
||||
|
||||
with_causal([&](auto fusion) {
|
||||
if (options.d <= 64) {
|
||||
run_bwd_64(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 128) {
|
||||
run_bwd_128(fusion, options, hw_info);
|
||||
}
|
||||
else {
|
||||
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
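// Any argument containing a comma is expanded into one invocation per value, e.g.
// "--k=1024,2048" re-enters main() twice, once with --k=1024 and once with --k=2048,
// so a single command line can sweep several problem sizes.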
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,832 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file An MLA (Multi-Head Latent Attention) inference kernel sample for the
|
||||
NVIDIA Blackwell Architecture.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
#include <cmath>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "reference/fmha_mla_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 1;
|
||||
int k = 256;
|
||||
int split_kv = -1; // number of splits along the K dim.
|
||||
bool is_var_split_kv = false;
|
||||
int max_split_kv = 16;
|
||||
int page = -1;
|
||||
float spread = 0.2f;
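// split_kv > 1 splits work along the K/V sequence across CTAs (partial results are
// later combined); page > 0 switches to a paged KV-cache with that page size; spread
// is the relative +/- variation of per-batch sequence lengths around K when paging.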
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_c = InitStyle::kRandom;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
if (k == -1) k = defaults.k;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
|
||||
cmd.get_cmd_line_argument("page", page, defaults.page);
|
||||
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
|
||||
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
|
||||
if (page == -1) {
|
||||
is_var_split_kv = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("max_split_kv", max_split_kv, defaults.max_split_kv);
|
||||
if (is_var_split_kv == true) {
|
||||
split_kv = max_split_kv;
|
||||
}
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_c, defaults.init_style_c);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-c", init_style_c, init_style_c);
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_mla\n\n"
|
||||
<< " This example showcases the use of CUTLASS for fused multi-head latent\n"
|
||||
<< " attention kernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --page=<int> Enables paging and sets the page size\n"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --spread=<float> Relative spread away from K for paging\n"
|
||||
<< " --split_kv=<int> Split KV factor\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) -1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 64; i ++) {
|
||||
for (int j = 0; j < 64; j++) {
|
||||
data[j + 64*i] = static_cast<Element>((double) (i % 9));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
double tbytes_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class TileShape,
|
||||
class PersistenceOption = IsPersistent<true>
|
||||
>
|
||||
struct Runner {
|
||||
|
||||
#ifdef FP8
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
#elif FP16
|
||||
using Element = cutlass::half_t;
|
||||
#else
|
||||
#error "Must either define FP8 or FP16"
|
||||
#endif
|
||||
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler = std::conditional_t<
|
||||
PersistenceOption::value,
|
||||
Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler
|
||||
>;
|
||||
|
||||
using Kernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler
|
||||
>;
|
||||
using Operation = cutlass::fmha::device::MLA<Kernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideQ stride_Q_latent;
|
||||
StrideK stride_C_latent;
|
||||
StrideQ stride_Q_rope;
|
||||
StrideK stride_K_rope;
|
||||
StrideO stride_O;
|
||||
StrideLSE stride_LSE;
|
||||
StrideLSE stride_PT;
|
||||
|
||||
uint64_t seed = 0;
|
||||
|
||||
int page_size = -1;
|
||||
int page_count = -1;
|
||||
|
||||
// We allocate Q and C as first latent, then rope
|
||||
// This means that we offset the pointer by HeadDim_latent to get the rope
|
||||
// portion
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_C;
|
||||
DeviceAllocation<ElementOut> block_O;
|
||||
DeviceAllocation<int> block_seq;
|
||||
DeviceAllocation<int> block_PT;
|
||||
DeviceAllocation<int> block_split_kv;
|
||||
DeviceAllocation<int> block_accum_split_len;
|
||||
DeviceAllocation<ElementAcc> block_LSE;
|
||||
DeviceAllocation<ElementOut> block_ref_O;
|
||||
DeviceAllocation<ElementAcc> block_ref_LSE;
|
||||
|
||||
ElementAcc scale;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(const ProblemShape& problem_shape) {
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
int page_K = K;
|
||||
int page_B = B;
|
||||
if (block_PT.get() != nullptr) {
|
||||
page_K = page_size;
|
||||
page_B = page_count;
|
||||
}
|
||||
|
||||
Tensor mQ_latent = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
cute::make_tuple(H, D_latent, B),
|
||||
stride_Q_latent);
|
||||
|
||||
Tensor mQ_rope = make_tensor(make_gmem_ptr(block_Q.get() + D_latent),
|
||||
cute::make_tuple(H, D_rope, B),
|
||||
stride_Q_rope);
|
||||
|
||||
Tensor mC_latent = make_tensor(make_gmem_ptr(block_C.get()),
|
||||
cute::make_tuple(page_K, D_latent, page_B),
|
||||
stride_C_latent);
|
||||
|
||||
Tensor mK_rope = make_tensor(make_gmem_ptr(block_C.get() + D_latent),
|
||||
cute::make_tuple(page_K, D_rope, page_B),
|
||||
stride_K_rope);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
|
||||
cute::make_tuple(H, D_latent, B),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
|
||||
cute::make_tuple(H, B),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mSeq = make_tensor(make_gmem_ptr(static_cast<int*>(block_seq.get())), make_shape(B));
|
||||
Tensor mPT = make_tensor(make_gmem_ptr(static_cast<int*>(block_PT.get())), make_shape(ceil_div(K, page_size), B), stride_PT);
|
||||
|
||||
fmha_mla_reference(problem_shape, mSeq, mPT, mQ_latent, mQ_rope, mC_latent, mK_rope, mO, mLSE, scale);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
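// Tolerances: 1-byte (fp8) element types get looser bounds than fp16/bf16.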
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
#ifdef B2B
|
||||
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#else
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#endif
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
std::cerr << "failed O: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
bool passed_LSE = true;
|
||||
#ifndef B2B
|
||||
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_LSE) {
|
||||
std::cerr << "failed LSE: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
|
||||
ProblemShape initialize(const Options& options) {
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, options.k, TileShapeD{}, options.b);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
// the scale is based on the non-absorbed sizes, change as appropriate
|
||||
// we can't determine this parameter from the info we have, it's an input
|
||||
int D_non_latent = 128;
|
||||
scale = static_cast<decltype(scale)>(1.0 / sqrt(1.0 * (D_non_latent + D_rope)));
|
||||
// Shape (H, D, B)
|
||||
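// Q stores the latent part first and the rope part immediately after it per head
// (see the allocation note above), so both views share the head stride
// D_latent + D_rope, and the rope view is the base pointer offset by D_latent.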
stride_Q_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
|
||||
stride_Q_rope = stride_Q_latent;
|
||||
stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
|
||||
block_Q.reset(static_cast<size_t>(options.b) * H * (D_latent + D_rope));
|
||||
block_O.reset(static_cast<size_t>(options.b) * H * D_latent);
|
||||
block_LSE.reset(static_cast<size_t>(options.b) * H);
|
||||
block_ref_O.reset(static_cast<size_t>(options.b) * H * D_latent);
|
||||
block_ref_LSE.reset(static_cast<size_t>(options.b) * H);
|
||||
|
||||
if (options.page == -1) {
|
||||
|
||||
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(options.k) * (D_latent + D_rope));
|
||||
stride_K_rope = stride_C_latent;
|
||||
|
||||
block_C.reset(static_cast<size_t>(options.b) * options.k * (D_latent + D_rope));
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
float spread = options.spread;
|
||||
int max_K = static_cast<int>((1 + spread) * K);
|
||||
int min_K = static_cast<int>((1 - spread) * K);
|
||||
page_size = options.page;
|
||||
page_count = B * ceil_div(max_K, page_size);
|
||||
stride_PT = cute::make_stride(_1{}, page_count);
|
||||
|
||||
std::vector<int> host_seq(B);
|
||||
std::vector<int> host_PT(page_count * B);
|
||||
|
||||
for (int i = 0; i < B; i++) {
|
||||
int seq = min_K + rand() % (max_K - min_K + 1);
|
||||
host_seq[i] = seq;
|
||||
for (int j = 0; j < ceil_div(seq, page_size); j++) {
|
||||
host_PT[page_count * i + j] = i + j * B;
|
||||
}
|
||||
}
|
||||
|
||||
block_seq.reset(host_seq.size());
|
||||
block_seq.copy_from_host(host_seq.data(), host_seq.size());
|
||||
block_PT.reset(host_PT.size());
|
||||
block_PT.copy_from_host(host_PT.data(), host_PT.size());
|
||||
|
||||
get<1>(problem_shape) = max_K;
|
||||
|
||||
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, page_size * static_cast<int64_t>((D_latent + D_rope)));
|
||||
stride_K_rope = stride_C_latent;
|
||||
|
||||
block_C.reset(page_count * page_size * static_cast<int64_t>((D_latent + D_rope)));
|
||||
|
||||
if (options.is_var_split_kv == true) {
|
||||
std::vector<int> host_split_kv(B);
|
||||
for(int i = 0; i < B; ++i) {
|
||||
auto len = host_seq[i];
|
||||
int split = ceil_div(options.max_split_kv, ceil_div(max_K, len));
|
||||
host_split_kv[i] = split;
|
||||
}
|
||||
block_split_kv.reset(B);
|
||||
block_split_kv.copy_from_host(host_split_kv.data(), host_split_kv.size());
|
||||
}
|
||||
}
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_C, seed + 2022, options.init_style_c);
|
||||
|
||||
return problem_shape;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
|
||||
ProblemShape problem_shape = initialize(options);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
{ scale,
|
||||
block_Q.get(), stride_Q_latent,
|
||||
block_Q.get() + D_latent, stride_Q_rope,
|
||||
block_C.get(), stride_C_latent,
|
||||
block_C.get() + D_latent, stride_K_rope,
|
||||
block_seq.get(),
|
||||
block_PT.get(), stride_PT,
|
||||
page_count, page_size},
|
||||
{ block_O.get(),
|
||||
stride_O,
|
||||
block_LSE.get(),
|
||||
stride_LSE},
|
||||
hw_info,
|
||||
options.split_kv,
|
||||
options.is_var_split_kv ? block_split_kv.get() : nullptr
|
||||
};
|
||||
if (options.split_kv < 0 && !options.is_var_split_kv) {
|
||||
Operation::set_split_kv(arguments);
|
||||
}
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
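// FLOP model: two GEMMs per batch element and head, S = Q*K^T over (D_latent + D_rope)
// columns and O = P*V over D_latent columns, i.e.
// 2 * B * H * K * ((D_latent + D_rope) + D_latent) = 2 * B * H * K * (2*D_latent + D_rope).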
double flops = 1.0;
|
||||
flops *= B;
|
||||
flops *= K;
|
||||
flops *= H;
|
||||
flops *= 2.0;
|
||||
flops *= (2.0 * D_latent + D_rope);
|
||||
|
||||
double bytes_q = sizeof(Element);
|
||||
bytes_q *= B;
|
||||
bytes_q *= H;
|
||||
bytes_q *= (D_latent + D_rope);
|
||||
double bytes_c = sizeof(Element);
|
||||
bytes_c *= B;
|
||||
bytes_c *= options.k; // K may be max_K here
|
||||
bytes_c *= (D_latent + D_rope);
|
||||
double bytes_o = sizeof(ElementOut);
|
||||
bytes_o *= B;
|
||||
bytes_o *= H;
|
||||
bytes_o *= D_latent;
|
||||
double bytes = bytes_q + bytes_c + bytes_o;
|
||||
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.tbytes_s = tbytes_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms * 1e3 << " us, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
Runner<decltype(shape), decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
using NumHeads = _128;
|
||||
using HeadDimLatent = _512;
|
||||
using HeadDim = Shape<HeadDimLatent, _64>;
|
||||
|
||||
std::cout << "###### B " << options.b << " MLA H " << 0 + NumHeads{} << " ";
|
||||
std::cout << "D_rope " << 0 + get<1>(HeadDim{}) << " D_latent " << 0 + get<0>(HeadDim{}) << " ";
|
||||
std::cout << "Q 1 K " << options.k << " Gen None ";
|
||||
std::cout << "Split " << options.split_kv << " Gen None ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
using Blocking = _128;
|
||||
std::string name = std::to_string((int) NumHeads{}) + "x" + std::to_string((int) Blocking{});
|
||||
std::string individual = " individual";
|
||||
std::string persistent = " persistent";
|
||||
#if FP8
|
||||
name += " fp8";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
#elif FP16
|
||||
name += " fp16";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8) || props.major != 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability major 10) and CUDA 12.8 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
run_mla(options, hw_info);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
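  // If an argument value contains commas (e.g. --k=512,1024), expand it by re-invoking
  // main() once per value so a single command line sweeps several configurations.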
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
result = main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -28,14 +28,12 @@
|
||||
|
||||
|
||||
set_property(
|
||||
SOURCE
|
||||
77_blackwell_fmha.cu
|
||||
77_blackwell_fmha_gen.cu
|
||||
77_blackwell_mla.cu
|
||||
77_blackwell_fmha_bwd.cu
|
||||
PROPERTY
|
||||
COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
|
||||
)
|
||||
SOURCE 77_blackwell_fmha.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
|
||||
|
||||
set_property(
|
||||
SOURCE 77_blackwell_fmha_gen.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
|
||||
|
||||
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
|
||||
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
@ -50,98 +48,58 @@ set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")))
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_fp8
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_fp8
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8)
|
||||
|
||||
foreach(PREC fp8 fp16)
|
||||
string(TOUPPER "${PREC}" PREC_MACRO)
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_fp16
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_${PREC}
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_${PREC}
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_mla_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_2sm_cpasync_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
|
||||
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_b2b_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
|
||||
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_sat_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
|
||||
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
|
||||
endforeach()
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_fp16
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -21,68 +21,3 @@ To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easies
|
||||
The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements.
It is well-suited for applying masks or activations.
More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA.
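
As a purely illustrative sketch (the concrete functor interface is defined in `collective/fmha_fusion.hpp` and may differ in its exact template parameters and index accessor), an elementwise fusion in the style of the existing mask functors could look like the following, applying a tanh soft-cap to the S = Q*K^T accumulator before softmax:

```cpp
// Hypothetical fusion functor, modeled loosely on the mask functors in
// collective/fmha_fusion.hpp; names and signature are assumptions, not the real API.
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cute/tensor.hpp"

struct TanhSoftcapFusion {
  float cap = 30.0f;

  template <class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE void apply_mask(AccQK& acc_qk, IndexQK const& index_qk, ProblemSize const& problem_size) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < cute::size(acc_qk); i++) {
      // acc_qk(i) holds one element of S = Q*K^T; index_qk(i) gives its logical
      // (q, k) position if positional logic (e.g. masking) is needed.
      acc_qk(i) = cap * cutlass::fast_tanh(acc_qk(i) * (1.0f / cap));
    }
  }
};
```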
|
||||
# FMHA for Blackwell: Backward

This sample provides code for the fused multi-head attention backward pass.
It supports HeadDims of 64 and 128, and fp8, fp16, and bf16 input data types.
The blocking in sequence length Q and K is 128, and loads are done via TMA.
We support causal masking.
The structure of this code is very similar to the forward pass, and the techniques are analogous.

There are three kernels to compute the backward pass, chained as sketched below:
1. `FmhaKernelBwdSumOdO` to compute the row-wise sum of the pointwise product of O and dO.
2. `Sm100FmhaBwdKernelTmaWarpSpecialized` to compute the backward pass proper.
3. `FmhaKernelBwdConvert` to convert dQ from fp32 to the final output precision.

`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high-performance fused kernel.
|
||||
|
||||
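For orientation, here is a condensed sketch of how the three kernels are chained; it mirrors `Sm100FmhaBwd::run` in the device layer shown further down, with error handling abbreviated, and assumes a `Params` object with the members used there:

```cpp
// Condensed, illustrative version of the three-kernel backward chain
// (cf. Sm100FmhaBwd::run in the device layer below).
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

template <class Params>
cutlass::Status run_backward(Params& params, cudaStream_t stream = nullptr) {
  // 1. Row-wise sum of O * dO (and the scaled LSE) consumed by the main backward kernel.
  if (params.op_sum_OdO.run(stream) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorInternal;
  }
  // dQ is accumulated into an fp32 workspace, so the accumulator must start zeroed.
  if (cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }
  // 2. Main warp-specialized kernel: computes dK and dV and accumulates dQ in fp32.
  if (params.op.run(stream) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorInternal;
  }
  // 3. Convert the fp32 dQ accumulator to the final output precision.
  return params.op_convert.run(stream);
}
```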
# MLA Inference for Blackwell
|
||||
|
||||
This sample provides code for fused multi-head latent attention inference in
the weight-absorbed regime, i.e. for latent head dim 512 and rope head dim 64.
It supports fp16, bf16, and fp8 input and output types.

To accommodate the large output accumulator due to the large latent head dimension,
the sample demonstrates how to leverage 2SM Blackwell tensor cores.

Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
for support of any power-of-two page size less than or equal to 128.
With paging, the code also supports variable sequence length.

The approach of this implementation is to reuse the selection logic of the collective GEMM builder and recombine the result into an MLA kernel.

The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back GEMM (essentially turning the softmax into a no-op) for fp8 and fp16.
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or each binary's `--help` output.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -1,92 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace cutlass::fmha {
|
||||
|
||||
struct Pow2 {
|
||||
int n;
|
||||
int log2_n;
|
||||
|
||||
explicit CUTE_DEVICE Pow2(int n) : n(n) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
log2_n = __ffs(n) - 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE T operator *(T const& b) const {
|
||||
return n * b;
|
||||
}
|
||||
|
||||
template<int N>
|
||||
CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {
|
||||
    if constexpr ((N & (N - 1)) == 0) {
      return Pow2{n * N};
    }
    else {
      return n * N;
    }
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
|
||||
return a >> b.log2_n;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
|
||||
return a & (b.n - 1);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
|
||||
return a < b.n;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void print(Pow2 const& a) {
|
||||
printf("2^%d", a.log2_n);
|
||||
}
|
||||
|
||||
} // end namespace cutlass::fmha
|
||||
|
||||
namespace cute {
|
||||
|
||||
template <>
|
||||
struct is_integral<cutlass::fmha::Pow2> : true_type {};
|
||||
|
||||
} // end namespace cute
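// Illustrative usage (not part of this header): wrapping a power-of-two page size
// in Pow2 lets the index arithmetic above compile down to shifts and masks, e.g.
//
//   CUTE_DEVICE void split_index(int idx, cutlass::fmha::Pow2 page_size,
//                                int& page, int& offset) {
//     page   = idx / page_size;  // lowered to idx >> page_size.log2_n
//     offset = idx % page_size;  // lowered to idx & (page_size.n - 1)
//   }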
|
||||
@ -1,320 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "../device/fmha.hpp"
|
||||
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_convert.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Element, class ElementAccumulator, class TileShape, class Mask>
|
||||
class Sm100FmhaBwd {
|
||||
public:
|
||||
/// Argument structure: User API
|
||||
struct Arguments {
|
||||
// Q K D HB
|
||||
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
|
||||
|
||||
const Element* ptr_Q;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
|
||||
const Element* ptr_K;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
|
||||
const Element* ptr_V;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
|
||||
|
||||
const Element* ptr_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
|
||||
const ElementAccumulator* ptr_LSE;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
|
||||
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
|
||||
|
||||
Element* ptr_dQ;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
|
||||
Element* ptr_dK;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
|
||||
Element* ptr_dV;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
|
||||
|
||||
ElementAccumulator softmax_scale;
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
using OperationSumOdO = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
|
||||
>;
|
||||
using OperationConvert = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
|
||||
>;
|
||||
|
||||
using Operation = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
|
||||
>;
|
||||
using Kernel = typename Operation::Kernel;
|
||||
|
||||
struct Params {
|
||||
OperationSumOdO op_sum_OdO;
|
||||
Operation op;
|
||||
OperationConvert op_convert;
|
||||
ElementAccumulator* dQ_acc;
|
||||
size_t dQ_acc_size;
|
||||
};
|
||||
|
||||
private:
|
||||
Params params_;
|
||||
|
||||
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(
|
||||
Arguments const& args,
|
||||
ElementAccumulator* sum_odo = nullptr,
|
||||
ElementAccumulator* scaled_lse = nullptr) {
|
||||
using namespace cute;
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
auto log2_e = log2f(expf(1.0f));
|
||||
return typename OperationSumOdO::Arguments {
|
||||
args.problem_size,
|
||||
args.ptr_O, args.stride_O,
|
||||
args.ptr_dO, args.stride_dO,
|
||||
sum_odo, stride_sum_OdO,
|
||||
args.ptr_LSE, args.stride_LSE,
|
||||
scaled_lse, stride_scaled_lse,
|
||||
-1.0f, -log2_e
|
||||
};
|
||||
}
|
||||
|
||||
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
|
||||
using namespace cute;
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||
return typename OperationConvert::Arguments {
|
||||
args.problem_size,
|
||||
src, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
args.ptr_dQ, args.stride_dQ,
|
||||
nullptr, args.stride_dK,
|
||||
nullptr, args.stride_dV,
|
||||
args.softmax_scale
|
||||
};
|
||||
}
|
||||
|
||||
static typename Operation::Arguments to_bwd_arguments(
|
||||
Arguments const& args,
|
||||
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
|
||||
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
|
||||
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
|
||||
return typename Operation::Arguments{
|
||||
args.problem_size,
|
||||
{ args.ptr_Q, args.stride_Q,
|
||||
args.ptr_K, args.stride_K,
|
||||
args.ptr_V, args.stride_V,
|
||||
args.ptr_dO, args.stride_dO,
|
||||
scaled_lse, stride_scaled_lse,
|
||||
sum_OdO, stride_sum_OdO,
|
||||
dQ_acc, stride_dQ,
|
||||
args.softmax_scale },
|
||||
{ args.ptr_dK, args.stride_dK,
|
||||
args.ptr_dV, args.stride_dV },
|
||||
args.hw_info
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
Status status = Status::kSuccess;
|
||||
|
||||
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = OperationConvert::can_implement(to_convert_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = Operation::can_implement(to_bwd_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
size_t workspace_bytes = 0;
|
||||
// OdO vector
|
||||
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
|
||||
// scaled LSE vector
|
||||
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
|
||||
// FP32 versions of outputs that are churned (start off with Q only)
|
||||
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Initializes state from arguments.
|
||||
Status
|
||||
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
|
||||
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
|
||||
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
|
||||
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
|
||||
params_.dQ_acc = dQ_acc;
|
||||
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
|
||||
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
|
||||
auto args_convert = to_convert_arguments(args, dQ_acc);
|
||||
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
|
||||
params_.op_convert.initialize(args_convert, nullptr, stream);
|
||||
auto args_bwd = to_bwd_arguments(
|
||||
args, sum_OdO, args_sum_OdO.stride_sum_OdO,
|
||||
scaled_lse, args_sum_OdO.stride_scaled_lse,
|
||||
dQ_acc, args_convert.stride_src_dQ
|
||||
);
|
||||
params_.op.initialize(args_bwd, nullptr, stream);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Initializes state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
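    // Carve the single workspace allocation into the three regions sized in
    // get_workspace_size(): the sum of O*dO, the scaled LSE, and the fp32 dQ accumulator.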
char* workspace_chr = reinterpret_cast<char*>(workspace);
|
||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
|
||||
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
|
||||
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
|
||||
|
||||
Status result = Status::kSuccess;
|
||||
result = params.op_sum_OdO.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
|
||||
if (cuda_result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = params.op.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = params.op_convert.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,357 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief A universal device layer for CUTLASS 3.x-style kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
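  // Heuristic for choosing a split-KV factor when the user did not request one
  // (split_kv < 1): K is tiled in chunks of 128, so at most ceil_div(K, 128) splits are
  // useful; the split is further capped by the SMs available per batch entry and then
  // rounded so the K tiles divide as evenly as possible across waves.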
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
int max_splits = ceil_div(K, 128);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}

  /// Initializes GEMM state from arguments.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params {kernel_params, reduction_params};

    if (is_initialized()) return Status::kSuccess;

    // account for dynamic smem capacity if needed
    // no dynamic smem is needed for reduction kernel
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    auto fmha_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params {fmha_params, reduction_params};

    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static, allowing users to create and manage their own params.
  /// The supplied params struct must be constructed by calling Kernel::to_underlying_arguments().
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
                   cute::size<1>(typename Kernel::ClusterShape{}),
                   cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<Kernel>;
      void* kernel_params[] = {&params.fmha_params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess != result or Status::kSuccess != launch_result) {
      //return Status::kSuccess;
      CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
      // launch reduction kernel
      dim3 const block = ReductionKernel::get_block_shape();
      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
      cudaError_t result = cudaGetLastError();
      if (cudaSuccess == result) {
        return Status::kSuccess;
      }
      else {
        CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
        return Status::kErrorInternal;
      }
    }
    else {
      return Status::kSuccess;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////
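
// Illustrative host-side usage sketch for the device-layer adapter above. This is
// not part of the diff; the operation alias, the argument construction, and the
// CUTLASS_CHECK helper are hypothetical placeholders, and only the member functions
// called (set_split_kv, can_implement, get_workspace_size, initialize, run) come
// from the code shown here.
//
//   using Operation = cutlass::fmha::device::MLA</* kernel types as configured by the example */>;
//
//   typename Operation::Arguments args = build_arguments(problem, tensors);  // hypothetical helper
//   Operation::set_split_kv(args.kernel);   // assumes Arguments exposes its KernelArguments member
//
//   if (Operation::can_implement(args) != cutlass::Status::kSuccess) { return; }
//
//   size_t workspace_bytes = Operation::get_workspace_size(args);
//   cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);
//
//   Operation op;
//   CUTLASS_CHECK(op.initialize(args, workspace.get(), stream));
//   CUTLASS_CHECK(op.run(stream));   // FMHA kernel, then the reduction kernel when split_kv > 1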
@@ -1,146 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/


#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"

namespace cutlass::fmha::kernel {

using namespace cute;

template<class Element, class ElementAcc>
struct FmhaKernelBwdConvert {

  struct Arguments {
    tuple<int, int, int, tuple<int, int>> problem_size;

    const ElementAcc* ptr_src_dQ;
    tuple<int, _1, tuple<int, int>> stride_src_dQ;
    const ElementAcc* ptr_src_dK;
    tuple<int, _1, tuple<int, int>> stride_src_dK;
    const ElementAcc* ptr_src_dV;
    tuple<int, _1, tuple<int, int>> stride_src_dV;

    Element* ptr_dest_dQ;
    tuple<int, _1, tuple<int, int>> stride_dest_dQ;
    Element* ptr_dest_dK;
    tuple<int, _1, tuple<int, int>> stride_dest_dK;
    Element* ptr_dest_dV;
    tuple<int, _1, tuple<int, int>> stride_dest_dV;

    ElementAcc scale = 1.0;
  };

  using Params = Arguments;

  using ClusterShape = Shape<_1, _1, _1>;
  static constexpr int SharedStorageSize = 0;

  static const int MinBlocksPerMultiprocessor = 1;
  static const int MaxThreadsPerBlock = 128;
  using ArchTag = cutlass::arch::Sm90;

  static const int kBlockSeq = 8;

  static size_t get_workspace_size(Arguments const& args) { return 0; }
  static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
    return cutlass::Status::kSuccess;
  }

  static const int kNumThreadsD = 16;
  static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
  static const int kElementsPerLoad = 4;

  static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;

  static bool can_implement(Arguments const& args) {
    return get<2>(args.problem_size) % kElementsPerLoad == 0;
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq));
    return grid;
  }

  static dim3 get_block_shape() {
    dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
    return block;
  }

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return args;
  }

  template<class StrideSrc, class StrideDest>
  CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
    auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
    auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;

    for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
      int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
      if (idx_s >= count) continue;
      auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
      auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);

      for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
        ElementAcc value_src[kElementsPerLoad];
        Element value_dest[kElementsPerLoad];

        using VecSrc = uint_bit_t<sizeof_bits_v<ElementAcc> * kElementsPerLoad>;
        using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
        *reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);

        for (int v = 0; v < kElementsPerLoad; v++) {
          value_dest[v] = static_cast<Element>(params.scale * value_src[v]);
        }

        *reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
      }
    }
  }

  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
    if (params.ptr_src_dQ != nullptr) {
      copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size));
    }
    if (params.ptr_src_dK != nullptr) {
      copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size));
    }
    if (params.ptr_src_dV != nullptr) {
      copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size));
    }
  }
};

} // namespace cutlass::fmha::kernel
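
// Illustrative host-side launch sketch for FmhaKernelBwdConvert. This is not part
// of the diff; the element types and argument initialization are hypothetical,
// while get_grid_shape/get_block_shape/to_underlying_arguments and the
// cutlass::device_kernel launch wrapper mirror how the device adapter earlier in
// this diff launches its kernels.
//
//   using Convert = cutlass::fmha::kernel::FmhaKernelBwdConvert<cutlass::half_t, float>;
//
//   typename Convert::Arguments args = { /* problem_size, src/dest pointers and strides */ };
//   // can_implement() requires the head dimension (get<2>(problem_size)) to be a
//   // multiple of kElementsPerLoad = 4.
//   if (! Convert::can_implement(args)) { return; }
//
//   typename Convert::Params params = Convert::to_underlying_arguments(args, /*workspace=*/nullptr);
//   dim3 grid  = Convert::get_grid_shape(params);   // x/y: the two grouping dims (likely heads, batch);
//                                                   // z: ceil_div(max of the two sequence lengths, kBlockSeq)
//   dim3 block = Convert::get_block_shape();        // (kNumThreadsD, kNumThreadsSeq, 1) = (16, 8, 1)
//   cutlass::device_kernel<Convert><<<grid, block, 0, stream>>>(params);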
@@ -1,151 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/


#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"

namespace cutlass::fmha::kernel {

using namespace cute;

template<class Element, class ElementAcc>
struct FmhaKernelBwdSumOdO {

  struct Arguments {
    cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;

    const Element* ptr_O;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
    const Element* ptr_dO;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;

    ElementAcc* ptr_sum_OdO;
    cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;

    const ElementAcc* ptr_lse = nullptr;
    cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;

    ElementAcc* ptr_scaled_lse = nullptr;
    cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;

    ElementAcc sum_odo_scale = 1.0;
    ElementAcc lse_scale = 1.0;
  };

  using Params = Arguments;

  using ClusterShape = Shape<_1, _1, _1>;
  static constexpr int SharedStorageSize = 0;

  static const int MinBlocksPerMultiprocessor = 1;
  static const int MaxThreadsPerBlock = 128;
  using ArchTag = cutlass::arch::Sm100;

  static size_t get_workspace_size(Arguments const& args) { return 0; }
  static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
    return cutlass::Status::kSuccess;
  }

  static const int kBlockQ = 16;

  static const int kNumThreadsD = 8;
  static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
  static const int kElementsPerLoad = 2;

  static const int kIterationsQ = kBlockQ / kNumThreadsQ;

  static bool can_implement(Arguments const& args) {
    return get<2>(args.problem_size) % kElementsPerLoad == 0;
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size));
    return grid;
  }

  static dim3 get_block_shape() {
    dim3 block(kNumThreadsD, kNumThreadsQ, 1);
    return block;
  }

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return args;
  }

  CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
    auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
    auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
    auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
    auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
    auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);

    CUTLASS_PRAGMA_UNROLL
    for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
      int idx_q = idx_q_t + kBlockQ * blockIdx.x;
      if (idx_q >= get<0>(params.problem_size)) continue;
      ElementAcc acc = 0;
      auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
      auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
      auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO);
      auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
      auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);

      for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
        Element value_O[kElementsPerLoad];
        Element value_dO[kElementsPerLoad];

        using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
        *reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
        *reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);

        for (int v = 0; v < kElementsPerLoad; v++) {
          acc += value_O[v] * value_dO[v];
        }
      }

      for (int i = 1; i < kNumThreadsD; i *= 2) {
        acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
      }

      if (threadIdx.x == 0) {
        *ptr_sum_OdO_bhq = params.sum_odo_scale * acc;
        if (params.ptr_scaled_lse) {
          *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq;
        }
      }
    }
  }
};

} // namespace cutlass::fmha::kernel
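
// The per-row O·dO dot product above is finished with a butterfly reduction over
// the kNumThreadsD = 8 lanes that share one row. A minimal standalone CUDA sketch
// of that shuffle pattern (illustrative only, independent of the kernel above):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void butterfly_sum_demo() {
  // Each lane contributes its lane index; after the loop every lane in a group of
  // 8 consecutive lanes holds the sum over its group (width = 8 in __shfl_xor_sync).
  float acc = static_cast<float>(threadIdx.x);
  for (int i = 1; i < 8; i *= 2) {
    acc += __shfl_xor_sync(0xffffffffu, acc, i, 8);
  }
  if (threadIdx.x % 8 == 0) {
    printf("lanes %u-%u sum = %.0f\n", threadIdx.x, threadIdx.x + 7, acc);
  }
}

int main() {
  butterfly_sum_demo<<<1, 32>>>();   // one warp: groups {0-7}, {8-15}, {16-23}, {24-31}
  cudaDeviceSynchronize();
  return 0;
}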