Compare commits
91 Commits
cutlass-3.
...
v3.6.0
| Author | SHA1 | Date | |
|---|---|---|---|
| bf9da7b76c | |||
| 3d261a5974 | |||
| e1cd8c7866 | |||
| 33c584364e | |||
| 2b6cfd34d1 | |||
| 4c42f73fda | |||
| 80243e0b8c | |||
| b0e09d7cd3 | |||
| 8aa95dbb88 | |||
| d656afbd2a | |||
| 32e3c38aef | |||
| 9004ed2d1b | |||
| 19f51596e8 | |||
| e8a8b69365 | |||
| 08a49953a0 | |||
| a424ca6cf9 | |||
| be692b48b0 | |||
| 12626bcfe4 | |||
| f02913c34e | |||
| 03e3bffaec | |||
| e5f3caf145 | |||
| 83ae20c740 | |||
| b0c09ed077 | |||
| ea69cc2849 | |||
| f3a3bfcbf2 | |||
| d65266a868 | |||
| 5b50a8faaf | |||
| 08101d9d0c | |||
| 755194a7bd | |||
| 53668799b2 | |||
| cc3c29a81a | |||
| 0837a2a00a | |||
| 477a677317 | |||
| b27c49e84a | |||
| e2b0789927 | |||
| 44dae8b90e | |||
| 2991ce18d3 | |||
| 1ebda1ccef | |||
| 9f68995de5 | |||
| 3a8c01a18b | |||
| dbdae514e0 | |||
| 21d0534167 | |||
| 323c8170bf | |||
| 82f5075946 | |||
| 06e337758d | |||
| 7369adcaca | |||
| 6c3044136b | |||
| e1976daacc | |||
| f7b19de32c | |||
| 4dbf5dbed2 | |||
| f93a69134e | |||
| 3f084f7f3c | |||
| b0296bf682 | |||
| 865be73a97 | |||
| 8d8cfdf375 | |||
| fb170439e8 | |||
| 4e5a8f6853 | |||
| 7192f4ab23 | |||
| 2049c6c5a2 | |||
| e22ba590cd | |||
| 19b4c5e065 | |||
| 06b21349bc | |||
| eee0cab26c | |||
| 36cbfcf483 | |||
| 1f2b590da6 | |||
| 8b2a0408bd | |||
| fbd116c0e5 | |||
| 5b283c872c | |||
| be60a0b272 | |||
| 56b46e2d13 | |||
| 52fb43f30f | |||
| 843adf0408 | |||
| e48c7618e4 | |||
| c5239d8312 | |||
| d6580c3dc0 | |||
| 81b06ee0e0 | |||
| dbfced05e7 | |||
| 2448bb56e6 | |||
| 637b159063 | |||
| 033d9efd2d | |||
| acc3ee18a1 | |||
| 5c447dd84f | |||
| 7d49e6c7e2 | |||
| a40e08e9d5 | |||
| 8e7d9f483d | |||
| 19f3cc33f1 | |||
| f9ece1b42c | |||
| 28cbacbf64 | |||
| 8f7d2789b8 | |||
| c4e3e122e2 | |||
| 629f4653c3 |
323
CHANGELOG.md
323
CHANGELOG.md
@ -1,28 +1,97 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03)
|
||||
|
||||
- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
|
||||
+ [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
|
||||
+ [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
|
||||
+ [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
|
||||
- Improve [mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md).
|
||||
+ Added a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
+ Added [layout pre-shuffling](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L50-55) to optimize memory loading.
|
||||
+ Added [interleaved conversion](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu#L50-52) for `{INT4, UINT4, INT8}` x `{FP16, BF16}`.
|
||||
+ Other general optimizations.
|
||||
- The suffixes of the mixed input kernel schedules have been removed. Use `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` and `KernelTmaWarpSpecializedCooperative` instead.
|
||||
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md).
|
||||
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
|
||||
- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
|
||||
- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper).
|
||||
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h)
|
||||
- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit versions 12.6.
|
||||
|
||||
## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25)
|
||||
|
||||
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu)
|
||||
- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48)
|
||||
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
|
||||
[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
|
||||
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence:
|
||||
+ [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
|
||||
+ [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
|
||||
- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
|
||||
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py).
|
||||
- Support for residual add (beta != 0) in convolution kernels.
|
||||
- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
|
||||
- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
|
||||
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md).
|
||||
- Better support for MSVC as a host compiler.
|
||||
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
|
||||
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
|
||||
|
||||
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
|
||||
|
||||
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
|
||||
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
|
||||
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
|
||||
+ Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
|
||||
+ [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
|
||||
+ NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design!
|
||||
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
|
||||
- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
|
||||
+ Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
|
||||
+ Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
|
||||
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
|
||||
+ [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
|
||||
+ [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
|
||||
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
|
||||
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
|
||||
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
|
||||
- Fixes to greatly reduce build warnings.
|
||||
- Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
|
||||
|
||||
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
|
||||
- Improvements for Hopper [Group-GEMMs](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm).
|
||||
- Statically available [CUTLASS Version macros](./include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
|
||||
- Improvements for Hopper [Group-GEMMs](./examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm).
|
||||
- Updates and bugfixes from the community (thanks!).
|
||||
|
||||
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
|
||||
* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
|
||||
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm)
|
||||
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released.
|
||||
* Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
|
||||
* Expanded [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
|
||||
* Performance improvements to [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm)
|
||||
* Beta release of [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released.
|
||||
* Improved [CuTe documentation](./media/docs/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
|
||||
|
||||
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
|
||||
* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
|
||||
* [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
|
||||
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}.
|
||||
* [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
|
||||
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
|
||||
* [Copy Async based Hopper GEMMs](./test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
|
||||
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
|
||||
* Profiler support for lower-aligned Hopper GEMMs.
|
||||
* Performance Improvements to [Scatter-Gather Hopper Example](/examples/52_hopper_gather_scatter_fusion).
|
||||
* Performance Improvements to [Scatter-Gather Hopper Example](./examples/52_hopper_gather_scatter_fusion).
|
||||
* Sub-Byte type fixes and improvements.
|
||||
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
|
||||
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](./include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
|
||||
* Fusion support for backprop fusions including drelu, dgelu, and dbias.
|
||||
* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
|
||||
|
||||
@ -34,7 +103,7 @@
|
||||
* SM80 EVT support in C++ and Python.
|
||||
* Other SM90 epilogue improvements.
|
||||
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
|
||||
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
|
||||
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](./python/README.md) for details.
|
||||
* SM90 TF32 kernel improvements for all layouts.
|
||||
* SM90 rasterization direction support in the CUTLASS profiler.
|
||||
* Improvement for CUTLASS profiler build times.
|
||||
@ -42,65 +111,65 @@
|
||||
|
||||
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
|
||||
|
||||
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
|
||||
* New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
* [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](./examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
|
||||
* New [Epilogue Visitor Tree (EVT)](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
* [Stream-K](./include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](./include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
* New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
|
||||
* [Hopper GEMM+Permute](./examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
* New CUTLASS 2D Convolution Python interface. New [example](./examples/python/03_basic_conv2d.ipynb) here.
|
||||
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
|
||||
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](./python/README.md) and new [examples](./examples/python).
|
||||
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues.
|
||||
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
* An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* New [*warp-specialized persistent cooperative*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
* An [example](./examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
|
||||
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) support both row major and column major input matrix.
|
||||
* [Permute + GEMM fusion](examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue.
|
||||
* [Row Broadcast](include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
|
||||
* [Permute + GEMM fusion](./examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue.
|
||||
* [Row Broadcast](./include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
|
||||
* The GitHub branch is renamed from `master` to `main` in this release.
|
||||
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
|
||||
* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
* [A new conceptual operation hierarchy](media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](/media/docs/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](media/docs/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* Updates to [Compatibility](/README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](/README.md#Target-Architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
|
||||
* [CUTLASS library integration](/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
|
||||
* Support for [Hopper GEMMs](examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
|
||||
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](examples/48_hopper_warp_specialized_gemm), [49](examples/49_hopper_gemm_schedules_with_collective_builder), and [50](examples/50_hopper_gemm_with_epilogue_swizzle).
|
||||
* [CUTLASS library integration](./tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
|
||||
* Support for [Hopper GEMMs](./examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
|
||||
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](./examples/48_hopper_warp_specialized_gemm), [49](./examples/49_hopper_gemm_schedules_with_collective_builder), and [50](./examples/50_hopper_gemm_with_epilogue_swizzle).
|
||||
|
||||
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
|
||||
* [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
|
||||
* [Fused multi-head attention Kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
|
||||
* [Dual GEMM](/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency.
|
||||
* Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
|
||||
* [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hoppers new double precision matrix multiplication instructions.
|
||||
* [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary.
|
||||
* Optimized [Group Conv](/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
|
||||
* [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added
|
||||
* [kOptimized](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM.
|
||||
* [Stream-K](./examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
|
||||
* [Fused multi-head attention Kernel](./examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
|
||||
* [Dual GEMM](./examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency.
|
||||
* Hopper improves [double precision matrix multiplication](./test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
|
||||
* [BLAS3](./test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hoppers new double precision matrix multiplication instructions.
|
||||
* [ELL Block Sparse GEMM](./examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary.
|
||||
* Optimized [Group Conv](./examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
|
||||
* [Optimized DepthWise Conv](./examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added
|
||||
* [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM.
|
||||
* The restrictions are: 1) input ,output channel and group number should be multiple of (128 / sizeof(input element)). 2) The input filter size should be the same as the template parameter configuration.
|
||||
* [kFixedStrideDilation](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded.
|
||||
* The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration.
|
||||
* [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
|
||||
* [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
|
||||
* [kFixedStrideDilation](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded.
|
||||
* The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration.
|
||||
* [Scripts](./examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
|
||||
* [FP8 data type definition](./include/cutlass/float8.h) and [conversion routines](./include/cutlass/numeric_conversion.h#L1274-2115).
|
||||
* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
|
||||
|
||||
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
@ -109,54 +178,54 @@
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
|
||||
* [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
|
||||
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
|
||||
* Optimizations for [GEMM+Softmax](examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
|
||||
* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing.
|
||||
* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues.
|
||||
* [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue.
|
||||
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes:
|
||||
* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
|
||||
* Optimizations for CUTLASS's [Grouped GEMM](./examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](./examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
|
||||
* Optimizations for [GEMM+Softmax](./examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
|
||||
* [Grouped GEMM for Multihead Attention](./examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing.
|
||||
* [GEMM + Layer norm fusion for Ampere](./examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues.
|
||||
* [GEMM Epilogue Permutation Fusion](./examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue.
|
||||
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes:
|
||||
* kSingleGroup: output channel per group is multiple of Threadblock tile N.
|
||||
* kMultipleGroup: Threadblock tile N is multiple of output channel per group.
|
||||
* [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number.
|
||||
* Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
|
||||
* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
|
||||
* Standalone [Layernorm](./tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](./tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
|
||||
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
|
||||
|
||||
* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [First layer Convolution kernels](./test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](./test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](./python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
|
||||
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
|
||||
* [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/python/cutlass_library/rank_k_operation.py)
|
||||
* [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/python/cutlass_library/rank_k_operation.py)
|
||||
* [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/python/cutlass_library/symm_operation.py)
|
||||
* [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/python/cutlass_library/trmm_operation.py)
|
||||
* [Unit tests](/test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](/tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](/examples/35_gemm_softmax)
|
||||
* [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel.
|
||||
* [HERK](./test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
|
||||
* [SYRK](./test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
|
||||
* [SYMM](./test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/symm_operation.py)
|
||||
* [TRMM](./test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/trmm_operation.py)
|
||||
* [Unit tests](./test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](./examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](./tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](./examples/35_gemm_softmax)
|
||||
* [Gather and Scatter Fusion with GEMM](./examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel.
|
||||
* It can select random rows in a row major matrix.
|
||||
* It can select random columns in a column major matrix.
|
||||
* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
|
||||
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
|
||||
* Supported kernels: GEMM and CONV.
|
||||
* Supported types: fp16 and int8.
|
||||
* Supported architectures: Turing and Ampere.
|
||||
* [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
|
||||
* [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
|
||||
* [Transposed Convolution](./examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
|
||||
* [Utility functions](./tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
|
||||
* [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores in these kernels.
|
||||
* Epilogue enhancement:
|
||||
* Eliminate bank conflicts in int8 tensor core kernels.
|
||||
* Half2 usage if epilogue compute type is fp16.
|
||||
* More activation functions: Silu, Hardswish, Leaky Relu.
|
||||
* New elementwise fusion pattern for [residual block](/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
|
||||
* [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
|
||||
* New elementwise fusion pattern for [residual block](./include/cutlass/epilogue/thread/linear_combination_residual_block.h).
|
||||
* [Group GEMM](./examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
|
||||
* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
@ -166,17 +235,17 @@
|
||||
|
||||
* **TF32x3:** emulated single-precision using Tensor Cores
|
||||
* 45+ TFLOPs on NVIDIA A100
|
||||
* [GEMM SDK example](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* [GEMM SDK example](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
|
||||
* [Conv Fprop SDK example](/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* [Conv Fprop SDK example](./examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](./include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
|
||||
* [SDK example](/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](/include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
|
||||
* [SDK example](./examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](./include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](./examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
|
||||
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates from the community (thanks!)
|
||||
|
||||
@ -186,11 +255,11 @@
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](/include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](./examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](./include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](./include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Use these when accumulation and epilogue compute types are all `cutlass::half_t`
|
||||
* Tuning and bug fixes to [fused GEMM + GEMM example](/examples/13_two_tensor_op_fusion/)
|
||||
* Tuning and bug fixes to [fused GEMM + GEMM example](./examples/13_two_tensor_op_fusion/)
|
||||
* Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
|
||||
* Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
|
||||
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
|
||||
@ -205,27 +274,27 @@
|
||||
|
||||
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
|
||||
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](./include/cutlass/arch/memory.h) and [global load](./include/cutlass/arch/memory_sm80.h)
|
||||
* Fused operators with GEMM and Convolution
|
||||
* [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* 64b tensor strides and leading dimensions support for GEMMs
|
||||
* Affine rank=2 matrix layouts
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](/include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* Affine rank=2 matrix layouts
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
|
||||
* Accelerates over previous implementation by cutting down redundant math by 4x
|
||||
* Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface
|
||||
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
|
||||
* Updates to [quaternion.h](/include/cutlass/quaternion.h) and [functional.h](/include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Updates to [quaternion.h](./include/cutlass/quaternion.h) and [functional.h](./include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](./examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](./examples/22_quaternion_conv/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](./test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](./test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Many improvements to the epilogue.
|
||||
* Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Performance improvement for FP16 tensor core kernels
|
||||
* Bug fixes
|
||||
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
|
||||
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
|
||||
* Updated minimum CUDA Toolkit requirement to 10.2
|
||||
* [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
@ -234,17 +303,17 @@
|
||||
## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
|
||||
* Tensor reductions
|
||||
* _m_-to-_n_ reductions of tensors with affine layout
|
||||
* [Specializations](/test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](/test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* [Specializations](./test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](./test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* Custom reduction functors such as `cutlass::logical_and`
|
||||
* Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31)
|
||||
* Optimizations for 3-D convolution
|
||||
* [Optimized tile iterators](include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* [Optimized tile iterators](./include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* Full coverage of [forward](test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
|
||||
* [Fused Convolution+Convolution example](/examples/13_two_tensor_op_fusion/README.md)
|
||||
* [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md)
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
|
||||
|
||||
## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19)
|
||||
* Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs
|
||||
@ -252,11 +321,11 @@
|
||||
* Data type: FP32, complex<FP32>, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32
|
||||
* Spatial dimensions: 1-D, 2-D, and 3-D
|
||||
* Layout: NHWC, NCxHWx
|
||||
* Implicit GEMM convolution components:
|
||||
* Implicit GEMM convolution components:
|
||||
* Global memory iterators supporting Fprop, Dgrad, and Wgrad
|
||||
* `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
|
||||
* `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
|
||||
* [Documentation](/media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
* [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
|
||||
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
@ -264,21 +333,21 @@
|
||||
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
|
||||
* Minor Features:
|
||||
* [Activation functions](/include/cutlass/epilogue/thread/activation.h) such as [GeLU](/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](/include/cutlass/matrix.h) and [quaternion](/include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](/include/cutlass/constants.h)
|
||||
* [Activation functions](./include/cutlass/epilogue/thread/activation.h) such as [GeLU](./include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](./include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](./include/cutlass/constants.h)
|
||||
* NVIDIA Ampere GPU Architecture examples and documentation:
|
||||
* [Tensor Float 32](/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](/media/docs/gemm_api.md#efficient-epilogue)
|
||||
* [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue)
|
||||
|
||||
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
* Fast Tensor Core operations:
|
||||
* Fast Tensor Core operations:
|
||||
* Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Tensor Float 32, BFloat16, and double-precision data types
|
||||
* Mixed integer data types (int8, int4, bin1)
|
||||
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
|
||||
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
|
||||
* Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required)
|
||||
* Features:
|
||||
* SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
|
||||
@ -290,11 +359,11 @@
|
||||
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
|
||||
|
||||
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library)
|
||||
* API to launch compiled kernel instances for GEMM and planar complex GEMM
|
||||
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
|
||||
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
|
||||
* [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu)
|
||||
* [SDK Examples of Planar Complex GEMMs](./examples/10_planar_complex/planar_complex.cu)
|
||||
* Minor enhancements and bug fixes
|
||||
|
||||
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
|
||||
@ -304,10 +373,10 @@
|
||||
* Encapsulated functionality embodying modern C++11 programming techniques
|
||||
* Optimized containers and data types for efficient, generic, portable device code
|
||||
* Updates to:
|
||||
* [Quick start guide](/media/docs/quickstart.md)
|
||||
* [Documentation](/README.md#documentation)
|
||||
* [Utilities](/media/docs/utilities.md)
|
||||
* [CUTLASS Profiler](/media/docs/profiler.md)
|
||||
* [Quick start guide](./media/docs/quickstart.md)
|
||||
* [Documentation](./README.md#documentation)
|
||||
* [Utilities](./media/docs/utilities.md)
|
||||
* [CUTLASS Profiler](./media/docs/profiler.md)
|
||||
* Native Turing Tensor Cores
|
||||
* Efficient GEMM kernels targeting Turing Tensor Cores
|
||||
* Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands
|
||||
|
||||
375
CMakeLists.txt
375
CMakeLists.txt
@ -38,7 +38,7 @@ else()
|
||||
endif()
|
||||
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
|
||||
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++17 if set")
|
||||
|
||||
# To reduce duplicate version locations, parse the version out of the
|
||||
# main versions.h file and reuse it here.
|
||||
@ -59,6 +59,18 @@ project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_C
|
||||
|
||||
################################################################################
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
|
||||
set(CUTLASS_GNU_HOST_COMPILE ON CACHE BOOL "Using GNU tools for host code compilation")
|
||||
endif()
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "[Cc]lang")
|
||||
set(CUTLASS_CLANG_HOST_COMPILE ON CACHE BOOL "Using Clang tools for host code compilation")
|
||||
endif()
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
|
||||
set(CUTLASS_MSVC_HOST_COMPILE ON CACHE BOOL "Using MSVC tools for host code compilation")
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 11.3)
|
||||
@ -67,14 +79,13 @@ elseif (CUDA_VERSION VERSION_LESS 11.4)
|
||||
message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.")
|
||||
endif()
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.5)
|
||||
message(FATAL_ERROR "GCC version must be at least 7.5!")
|
||||
if(CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3)
|
||||
message(FATAL_ERROR "GCC version must be at least 7.3!")
|
||||
endif()
|
||||
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
|
||||
if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
|
||||
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
|
||||
endif()
|
||||
|
||||
find_package(Doxygen QUIET)
|
||||
|
||||
################################################################################
|
||||
@ -86,14 +97,11 @@ set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
|
||||
else()
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++17)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
|
||||
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
|
||||
endif()
|
||||
@ -134,16 +142,26 @@ set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUT
|
||||
|
||||
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
|
||||
set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests")
|
||||
set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest")
|
||||
set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.")
|
||||
if (CUTLASS_USE_PACKED_TUPLE)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_USE_PACKED_TUPLE=1")
|
||||
message(STATUS "Make cute::tuple be the new standard-layout tuple type")
|
||||
elseif()
|
||||
message(STATUS "Use the original cute::tuple implementation that is _not_ standard-layout")
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70 72 75 80 86 87)
|
||||
endif()
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 89 90)
|
||||
endif()
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a)
|
||||
endif()
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
@ -168,6 +186,7 @@ endif()
|
||||
include(GNUInstallDirs)
|
||||
|
||||
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
|
||||
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
@ -215,7 +234,7 @@ if (${CUTLASS_NVCC_VERBOSE})
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS NAMESPACE
|
||||
# CUTLASS NAMESPACE
|
||||
#
|
||||
set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS")
|
||||
|
||||
@ -233,15 +252,15 @@ set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.
|
||||
|
||||
set(KERNEL_FILTER_FILE "" CACHE STRING "KERNEL FILTER FILE FULL PATH")
|
||||
|
||||
if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS)
|
||||
if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS)
|
||||
# If a kernel filter file is specified, we want to generate and then
|
||||
# filter on the entire kernel set, not the default kernel
|
||||
# (sub)set. The user may overried CUTLASS_LIBRRARY_KERNELS, in which
|
||||
# (sub)set. The user may have overridden CUTLASS_LIBRARY_KERNELS, in which
|
||||
# case the resulting kernel set will be the intersection of the two
|
||||
# options differenced against CUTLASS_LIBRARY_IGNORE_KERNELS.
|
||||
set(CUTLASS_LIBRARY_KERNELS_INIT "*")
|
||||
else()
|
||||
set(CUTLASS_LIBRARY_KERNELS_INIT "")
|
||||
else()
|
||||
set(CUTLASS_LIBRARY_KERNELS_INIT "")
|
||||
endif()
|
||||
|
||||
if (KERNEL_FILTER_FILE)
|
||||
@ -255,9 +274,11 @@ if(KERNEL_FILTER_FILE)
|
||||
message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}")
|
||||
endif()
|
||||
|
||||
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
|
||||
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
|
||||
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
|
||||
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.")
|
||||
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.")
|
||||
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.")
|
||||
set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).")
|
||||
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")
|
||||
|
||||
################################################################################
|
||||
|
||||
@ -298,6 +319,15 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_
|
||||
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
|
||||
"Enable PTX mma instruction for collective matrix multiply operations.")
|
||||
|
||||
set(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES OFF CACHE BOOL
|
||||
"Enable an extended set of SM90 WGMMA instruction shapes (may lead to increased compilation times)")
|
||||
if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES)
|
||||
message(STATUS "Enabled extended SM90 WGMMA instruction shapes")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")
|
||||
|
||||
#
|
||||
# NOTE: running with asan and CUDA requires the following environment variable:
|
||||
#
|
||||
@ -325,13 +355,53 @@ if(CUTLASS_NVCC_EMBED_PTX)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_SKIP_REDUCTION_INIT)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SKIP_REDUCTION_INIT=1)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_PROFILER_DISABLE_REFERENCE OFF CACHE BOOL "Disable compilation of reference kernels in the CUTLASS profiler.")
|
||||
if (CUTLASS_PROFILER_DISABLE_REFERENCE)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_PROFILER_DISABLE_REFERENCE=1)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_GDC_FOR_SM90)
|
||||
message(STATUS "Grid Dependency Control (GDC) is enabled for SM90 kernels (required for programmatic dependent launches).")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_ENABLE_SYNCLOG OFF CACHE BOOL "Enable synchronization event logging for race condition debugging. WARNING: This redefines __syncthreads() and __syncwarp() in all downstream code!")
|
||||
|
||||
if (CUTLASS_ENABLE_SYNCLOG)
|
||||
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1")
|
||||
string(APPEND CMAKE_CUDA_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
# Warnings-as-error exceptions and warning suppressions for Clang builds
|
||||
if (CUTLASS_CLANG_HOST_COMPILE)
|
||||
|
||||
set(FLAGS_TO_ADD
|
||||
"-Wno-error=implicit-int-conversion"
|
||||
"-Wno-error=pass-failed"
|
||||
"-Wno-error=inconsistent-missing-override"
|
||||
"-Wno-sign-conversion"
|
||||
"-Wno-unused-parameter"
|
||||
)
|
||||
|
||||
foreach(FLAG ${FLAGS_TO_ADD})
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLAG}")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS "${FLAG}")
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS "${FLAG}")
|
||||
endforeach()
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
|
||||
# MSVC flow handles caching already, but for other generators we handle it here.
|
||||
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
|
||||
@ -342,9 +412,9 @@ endif()
|
||||
|
||||
if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING)
|
||||
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1)
|
||||
if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
|
||||
if (CUTLASS_GNU_HOST_COMPILE OR CUTLASS_CLANG_HOST_COMPILE)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c)
|
||||
elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC"))
|
||||
elseif(CUTLASS_MSVC_HOST_COMPILE)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2)
|
||||
endif()
|
||||
endif()
|
||||
@ -357,6 +427,7 @@ if (CUTLASS_ENABLE_OPENMP_TESTS)
|
||||
message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(UNIX)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
|
||||
@ -368,24 +439,13 @@ if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo)
|
||||
endif()
|
||||
|
||||
#Report CUDA build flags
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if(CUTLASS_CUDA_CLANG_FLAGS)
|
||||
message(STATUS "Using CLANG flags: ${CUTLASS_CUDA_CLANG_FLAGS}")
|
||||
endif()
|
||||
else()
|
||||
if(CUTLASS_CUDA_NVCC_FLAGS)
|
||||
message(STATUS "Using NVCC flags: ${CUTLASS_CUDA_NVCC_FLAGS}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
|
||||
if (CUTLASS_CLANG_DEVICE_COMPILE)
|
||||
if (NOT CUTLASS_CLANG_HOST_COMPILE)
|
||||
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
|
||||
endif()
|
||||
|
||||
# There are numerous Clang versions that can work with each CUDA toolkit and the
|
||||
# the checks are not very useful so we are turning them off and using testing to
|
||||
# There are numerous Clang versions that can work with each CUDA toolkit and the
|
||||
# the checks are not very useful so we are turning them off and using testing to
|
||||
# ensure the various combinations work properly.
|
||||
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
|
||||
@ -396,12 +456,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
|
||||
|
||||
string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
|
||||
list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR)
|
||||
list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR})
|
||||
|
||||
|
||||
# needed for libcublasLt.so in case it's installed in the same location as libcudart.so
|
||||
# dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags)
|
||||
# Otherwise linker uses RUNPATH and that does not propagate to loaded libs.
|
||||
@ -409,34 +465,54 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
|
||||
link_libraries(nvidia::cudart)
|
||||
link_libraries(nvidia::cuda_driver)
|
||||
|
||||
endif()
|
||||
|
||||
# Support for 128-bit integers if using NVIDIA C++ compiler
|
||||
#Report CUDA build flags
|
||||
if (CUTLASS_CLANG_DEVICE_COMPILE AND CUTLASS_CUDA_CLANG_FLAGS)
|
||||
set(__FLAG_GROUP Clang)
|
||||
set(__FLAG_LIST CUTLASS_CUDA_CLANG_FLAGS)
|
||||
else(CUTLASS_NVCC_DEVICE_COMPILE AND CUTLASS_CUDA_NVCC_FLAGS)
|
||||
set(__FLAG_GROUP NVCC)
|
||||
set(__FLAG_LIST CUTLASS_CUDA_NVCC_FLAGS)
|
||||
endif()
|
||||
|
||||
set(__FLAG_DISPLAY_STRING "")
|
||||
set(__FLAG_DISPLAY_SEPARATOR)
|
||||
list(JOIN ${__FLAG_LIST} "\n " __FLAG_DISPLAY_STRING)
|
||||
message(STATUS "Using the following ${__FLAG_GROUP} flags: \n ${__FLAG_DISPLAY_STRING}")
|
||||
|
||||
# Known gcc 8.1-8.3 SFINAE issue (fixed in gcc 8.4), check https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87748
|
||||
# Also see https://github.com/NVIDIA/nccl/issues/835 for nvtx3.hpp
|
||||
if (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS_EQUAL 8.3)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0")
|
||||
endif()
|
||||
|
||||
# Support for 128-bit integers if using NVIDIA C++ compiler
|
||||
if (${CMAKE_CXX_COMPILER_ID} MATCHES "PGI" OR ${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Mint128 ")
|
||||
endif()
|
||||
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
|
||||
# CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this
|
||||
# property for CMake 3.18+, so we request the NEW behavior for correct compatibility.
|
||||
# https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104
|
||||
cmake_policy(SET CMP0104 NEW)
|
||||
endif()
|
||||
# CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this
|
||||
# property for CMake 3.18+, so we request the NEW behavior for correct compatibility.
|
||||
# https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104
|
||||
cmake_policy(SET CMP0104 NEW)
|
||||
|
||||
if (MSVC)
|
||||
|
||||
|
||||
# MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
|
||||
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
|
||||
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
|
||||
# appropriate value given the requested --std option. This fixes a compilation issue mismatch
|
||||
# between GCC/Clang and MSVC.
|
||||
#
|
||||
# error : a constexpr function cannot have a nonliteral return type "dim3"
|
||||
#
|
||||
#
|
||||
# See https://developercommunity.visualstudio.com/t/msvc-incorrectly-defines-cplusplus/139261
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
|
||||
|
||||
|
||||
endif()
|
||||
|
||||
# Some tests require this build option in order to link.
|
||||
@ -457,59 +533,25 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
|
||||
set(ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS_ENABLED})
|
||||
endif()
|
||||
|
||||
set(NVCC_FLAGS)
|
||||
set(CLANG_FLAGS)
|
||||
set(__CMAKE_CUDA_ARCHS)
|
||||
foreach(ARCH ${ARCHS_ENABLED})
|
||||
list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
|
||||
set(CODES)
|
||||
if(CUTLASS_NVCC_EMBED_CUBIN)
|
||||
list(APPEND CODES sm_${ARCH})
|
||||
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real)
|
||||
endif()
|
||||
if(CUTLASS_NVCC_EMBED_PTX)
|
||||
list(APPEND CODES compute_${ARCH})
|
||||
if(CUTLASS_NVCC_EMBED_PTX AND NOT CUTLASS_CLANG_DEVICE_COMPILE)
|
||||
# If we're using clang for device compilation, the ptx is inserted
|
||||
# via another command line option and the `-virtual` flags will cause an error.
|
||||
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual)
|
||||
endif()
|
||||
list(JOIN CODES "," CODES_STR)
|
||||
list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}])
|
||||
endforeach()
|
||||
|
||||
if (NOT __SM_ARCHS)
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:${CLANG_FLAGS}>
|
||||
)
|
||||
elseif(CMAKE_VERSION GREATER_EQUAL 3.18)
|
||||
set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS})
|
||||
else()
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>
|
||||
)
|
||||
endif()
|
||||
else()
|
||||
list(JOIN CLANG_FLAGS " " CLANG_FLAGS_STR)
|
||||
list(JOIN NVCC_FLAGS " " STR_NVCC_FLAGS)
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if(${TARGET} MATCHES ".*\.cpp")
|
||||
set_source_files_properties(${TARGET} PROPERTIES COMPILE_FLAGS ${CLANG_FLAGS_STR})
|
||||
endif()
|
||||
elseif(CMAKE_VERSION GREATER_EQUAL 3.18)
|
||||
set_source_files_properties(${TARGET} PROPERTIES CUDA_ARCHITECTURES ${STR_NVCC_FLAGS})
|
||||
else()
|
||||
if(${TARGET} MATCHES ".*\.cu")
|
||||
set_source_files_properties(${TARGET} PROPERTIES COMPILE_FLAGS ${STR_NVCC_FLAGS})
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS})
|
||||
|
||||
endfunction()
|
||||
|
||||
# Cache the flags so they are available when the function below is called anywhere globally.
|
||||
# Cache the flags so they are available when the function below is called anywhere globally.
|
||||
|
||||
set(__CUTLASS_CUDA_FLAGS ${CUTLASS_CUDA_FLAGS} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} CACHE INTERNAL "")
|
||||
@ -526,8 +568,8 @@ set(__CUTLASS_CUDA_NVCC_FLAGS_DEBUG ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG} CACHE INTER
|
||||
|
||||
function(cutlass_apply_standard_compile_options TARGET)
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CUDA_COMPILE_LANGUAGE CXX)
|
||||
if(CUTLASS_CLANG_DEVICE_COMPILE)
|
||||
set(CUDA_COMPILE_LANGUAGE CUDA)
|
||||
set(_FLAGS ${__CUTLASS_CUDA_FLAGS} ${__CUTLASS_CUDA_CLANG_FLAGS})
|
||||
set(_FLAGS_RELEASE ${__CUTLASS_CUDA_FLAGS_RELEASE} ${__CUTLASS_CUDA_CLANG_FLAGS_RELEASE})
|
||||
set(_FLAGS_RELWITHDEBINFO ${__CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${__CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO})
|
||||
@ -620,8 +662,6 @@ target_include_directories(
|
||||
$<INSTALL_INTERFACE:include>
|
||||
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
|
||||
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/examples>
|
||||
)
|
||||
|
||||
# Mark CTK headers as system to supress warnings from them
|
||||
@ -680,6 +720,7 @@ if(NOT WIN32)
|
||||
"-Wl,-rpath,'$ORIGIN/../lib'"
|
||||
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'"
|
||||
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'"
|
||||
${CMAKE_DL_LIBS}
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -689,7 +730,11 @@ include(CTest)
|
||||
enable_testing()
|
||||
|
||||
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
|
||||
if (CUTLASS_USE_SYSTEM_GOOGLETEST)
|
||||
find_package(GTest REQUIRED)
|
||||
else()
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (NOT TARGET test_all)
|
||||
@ -734,30 +779,31 @@ set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default" CACHE STRING "Default
|
||||
with CUTLASS_TEST_SETS environment variable when running the ctest
|
||||
executable.")
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_BINDIR}")
|
||||
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake)
|
||||
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")
|
||||
|
||||
function(cutlass_add_executable_tests NAME TARGET)
|
||||
#
|
||||
# Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the
|
||||
#
|
||||
# Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the
|
||||
# <CMAKE_BINARY_DIR> or the <CMAKE_INSTALL_PREFIX>/<CUTLASS_TEST_INSTALL_PREFIX> after installation.
|
||||
#
|
||||
#
|
||||
# NAME: The base name for the test. Can be run with `make <NAME>` or `ctest -R 'c<NAME>'`.
|
||||
# TARGET: The target corresponding to the executable under test.
|
||||
# DISABLE_EXECUTABLE_INSTALL_RULE: An option, if given, that disables creating an install rule for TARGET.
|
||||
# DEPENDS: A list of targets or files on which this test is dependent.
|
||||
# DEPENDEES: A list of targets which should depend on this test.
|
||||
# TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
|
||||
# to pass to the test executable. A unique test is generated for each set of
|
||||
# to pass to the test executable. A unique test is generated for each set of
|
||||
# options given. If this option is not used, a single test with no arguments is generated.
|
||||
# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for
|
||||
# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for
|
||||
# generating the full variable name to be referenced.
|
||||
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
|
||||
# test results to speed up test runtime.
|
||||
# TEST_SETS_SUPPORTED: A list of test set names these tests support.
|
||||
#
|
||||
# TEST_SETS_SUPPORTED: A list of test set names these tests support.
|
||||
#
|
||||
|
||||
set(options DISABLE_EXECUTABLE_INSTALL_RULE)
|
||||
set(options DISABLE_EXECUTABLE_INSTALL_RULE DO_NOT_LOWERCASE_TEST_NAME)
|
||||
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX)
|
||||
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
@ -787,9 +833,9 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
endif()
|
||||
|
||||
if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS)
|
||||
|
||||
|
||||
# file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
|
||||
install(
|
||||
TARGETS ${TARGET}
|
||||
RUNTIME DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}
|
||||
@ -803,7 +849,7 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT __TEST_COMMAND_OPTIONS)
|
||||
@ -829,48 +875,8 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
|
||||
set(TEST_GROUP_NAME ${NAME})
|
||||
|
||||
foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TEST_NAME)
|
||||
else()
|
||||
string(TOLOWER "${NAME}" TEST_NAME)
|
||||
endif()
|
||||
|
||||
# The following rigmarole is needed to deal with spaces and possible quotes in
|
||||
# command line arguments. The options are passed "by reference" as the actual
|
||||
# variable names holding the real options. We then expand these in a way that
|
||||
# preserves any quotes. Note, they have to be in this order for it to work for
|
||||
# all the use cases below.
|
||||
|
||||
set(TEST_COMMAND_OPTIONS ${${__TEST_COMMAND_OPTIONS_PREFIX}${CMD_OPTIONS_VAR}})
|
||||
list(JOIN TEST_COMMAND_OPTIONS " " TEST_COMMAND_OPTIONS)
|
||||
separate_arguments(TEST_COMMAND_OPTIONS)
|
||||
|
||||
add_custom_target(
|
||||
${TEST_NAME}
|
||||
COMMAND
|
||||
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${TEST_COMMAND_OPTIONS}
|
||||
DEPENDS
|
||||
${TARGET}
|
||||
)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
add_dependencies(${NAME} ${TEST_NAME})
|
||||
endif()
|
||||
|
||||
foreach(DEPENDEE ${__DEPENDEES})
|
||||
add_dependencies(${DEPENDEE} ${TEST_NAME})
|
||||
endforeach()
|
||||
|
||||
set(TEST_NAME c${TEST_NAME})
|
||||
string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY)
|
||||
string(APPEND _INLINE_PER_TEST_CODE "${_TEST_CODE}")
|
||||
|
||||
endforeach()
|
||||
|
||||
# To run the tests from an install package with tests enabled, we need to generate test files
|
||||
# that don't rely on the current directory structure in build.
|
||||
# that don't rely on the current directory structure in build.
|
||||
|
||||
set(TEST_NAME c${NAME})
|
||||
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
|
||||
@ -884,17 +890,62 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format.
|
||||
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY)
|
||||
|
||||
foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
set(TESTCASE_NAME "${NAME}_${CMD_OPTIONS_VAR}")
|
||||
else()
|
||||
set(TESTCASE_NAME "${NAME}")
|
||||
endif()
|
||||
|
||||
if (NOT __DO_NOT_LOWERCASE_TEST_NAME)
|
||||
string(TOLOWER "${TESTCASE_NAME}" TESTCASE_NAME)
|
||||
endif()
|
||||
|
||||
# The following rigmarole is needed to deal with spaces and possible quotes in
|
||||
# command line arguments. The options are passed "by reference" as the actual
|
||||
# variable names holding the real options. We then expand these in a way that
|
||||
# preserves any quotes. Note, they have to be in this order for it to work for
|
||||
# all the use cases below.
|
||||
|
||||
set(TEST_COMMAND_OPTIONS ${${__TEST_COMMAND_OPTIONS_PREFIX}${CMD_OPTIONS_VAR}})
|
||||
list(JOIN TEST_COMMAND_OPTIONS " " TEST_COMMAND_OPTIONS)
|
||||
separate_arguments(TEST_COMMAND_OPTIONS)
|
||||
|
||||
add_custom_target(
|
||||
${TESTCASE_NAME}
|
||||
COMMAND
|
||||
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${TEST_COMMAND_OPTIONS}
|
||||
DEPENDS
|
||||
${TARGET}
|
||||
)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
add_dependencies(${NAME} ${TESTCASE_NAME})
|
||||
endif()
|
||||
|
||||
foreach(DEPENDEE ${__DEPENDEES})
|
||||
add_dependencies(${DEPENDEE} ${TESTCASE_NAME})
|
||||
endforeach()
|
||||
|
||||
set(TESTCASE_NAME c${TESTCASE_NAME})
|
||||
string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY)
|
||||
file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" "${_TEST_CODE}")
|
||||
file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" "${_TEST_CODE}")
|
||||
|
||||
endforeach()
|
||||
|
||||
# The following line imports the tests for immediate run via `make test`.
|
||||
|
||||
include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake)
|
||||
|
||||
|
||||
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
|
||||
|
||||
if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
file(GENERATE
|
||||
OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
|
||||
INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in"
|
||||
file(GENERATE
|
||||
OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
|
||||
INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in"
|
||||
)
|
||||
|
||||
install(
|
||||
@ -952,19 +1003,19 @@ endif()
|
||||
include(CMakePackageConfigHelpers)
|
||||
|
||||
write_basic_package_version_file(
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
|
||||
COMPATIBILITY AnyNewerVersion)
|
||||
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
|
||||
@ONLY
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/
|
||||
)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||

|
||||

|
||||
|
||||
[README](/README.md#documentation) > **Contributors**
|
||||
[README](./README.md#documentation) > **Contributors**
|
||||
|
||||
# CUTLASS Developers and Contributors
|
||||
|
||||
|
||||
114
CUDA.cmake
114
CUDA.cmake
@ -26,49 +26,46 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CUTLASS_NATIVE_CUDA_INIT ON)
|
||||
elseif(CMAKE_VERSION VERSION_LESS 3.12.4)
|
||||
set(CUTLASS_NATIVE_CUDA_INIT OFF)
|
||||
else()
|
||||
set(CUTLASS_NATIVE_CUDA_INIT ON)
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
message(WARNING "CUDA_COMPILER flag is deprecated, set CMAKE_CUDA_COMPILER to desired compiler executable.")
|
||||
set(__CLANG_DEVICE_COMPILATION_REQUESTED ON)
|
||||
elseif(CUDA_COMPILER)
|
||||
message(WARNING "Deprecated flag CUDA_COMPILER used with unknown argument ${CUDA_COMPILER}, ignoring.")
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NATIVE_CUDA ${CUTLASS_NATIVE_CUDA_INIT} CACHE BOOL "Utilize the CMake native CUDA flow")
|
||||
|
||||
if(NOT DEFINED ENV{CUDACXX} AND NOT DEFINED ENV{CUDA_BIN_PATH} AND DEFINED ENV{CUDA_PATH})
|
||||
# For backward compatibility, allow use of CUDA_PATH.
|
||||
set(ENV{CUDACXX} $ENV{CUDA_PATH}/bin/nvcc)
|
||||
if (__CLANG_DEVICE_COMPILATION_REQUESTED AND NOT DEFINED CMAKE_CUDA_COMPILER)
|
||||
set(CMAKE_CUDA_COMPILER clang++) # We will let the system find Clang or error out
|
||||
endif()
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA)
|
||||
enable_language(CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
if(NOT CUDA_VERSION)
|
||||
set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
|
||||
endif()
|
||||
if(NOT CUDA_TOOLKIT_ROOT_DIR)
|
||||
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
|
||||
endif()
|
||||
if(NOT CUDA_VERSION)
|
||||
# For backward compatibility with older CMake code.
|
||||
set(CUDA_VERSION ${CUDAToolkit_VERSION})
|
||||
set(CUDA_VERSION_MAJOR ${CUDAToolkit_VERSION_MAJOR})
|
||||
set(CUDA_VERSION_MINOR ${CUDAToolkit_VERSION_MINOR})
|
||||
endif()
|
||||
if(NOT CUDA_TOOLKIT_ROOT_DIR)
|
||||
# In some scenarios, such as clang device compilation, the toolkit root may not be set, so we
|
||||
# force it here to the nvcc we found via the CUDAToolkit package.
|
||||
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDAToolkit_NVCC_EXECUTABLE}/../.." ABSOLUTE)
|
||||
endif()
|
||||
|
||||
if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])")
|
||||
set(CUTLASS_NVCC_DEVICE_COMPILE ON CACHE BOOL "Using nvcc tools for device compilation")
|
||||
elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang")
|
||||
set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation")
|
||||
else()
|
||||
message(FATAL_ERROR "Uknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.")
|
||||
endif()
|
||||
|
||||
find_package(CUDA REQUIRED)
|
||||
# We workaround missing variables with the native flow by also finding the CUDA toolkit the old way.
|
||||
|
||||
if(NOT CMAKE_CUDA_COMPILER_VERSION)
|
||||
set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION})
|
||||
endif()
|
||||
|
||||
if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30")
|
||||
message(FATAL_ERROR "Clang device compilation for CUTLASS requires CMake 3.30 or higher.")
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 9.2)
|
||||
message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.")
|
||||
endif()
|
||||
if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
|
||||
message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}")
|
||||
message(FATAL_ERROR "CUDA 9.2+ required, found ${CUDA_VERSION}.")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
@ -211,16 +208,6 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
|
||||
# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
|
||||
# paths by default, so we add it explicitly here.
|
||||
|
||||
function(cutlass_correct_source_file_language_property)
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
foreach(File ${ARGN})
|
||||
if(File MATCHES ".*\.cu$")
|
||||
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all")
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
|
||||
else()
|
||||
@ -306,18 +293,13 @@ function(cutlass_add_library NAME)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
|
||||
else()
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
|
||||
endif()
|
||||
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
|
||||
if (NOT __SKIP_GENCODE_FLAGS)
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
endif()
|
||||
|
||||
target_compile_features(
|
||||
@ -326,6 +308,14 @@ function(cutlass_add_library NAME)
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
get_target_property(TARGET_TYPE ${NAME} TYPE)
|
||||
|
||||
if (TARGET_TYPE MATCHES "SHARED")
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Shared)
|
||||
elseif(TARGET_TYPE MATCHES "STATIC")
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Static)
|
||||
endif()
|
||||
|
||||
if(__EXPORT_NAME)
|
||||
add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
|
||||
set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
|
||||
@ -336,19 +326,22 @@ endfunction()
|
||||
function(cutlass_add_executable NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(oneValueArgs CUDA_RUNTIME_LIBRARY)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
if (NOT DEFINED __CUDA_RUNTIME_LIBRARY)
|
||||
set(__CUDA_RUNTIME_LIBRARY Shared)
|
||||
endif()
|
||||
|
||||
set(__CUDA_RUNTIME_LIBRARY_ALLOWED None Shared Static)
|
||||
if (NOT __CUDA_RUNTIME_LIBRARY IN_LIST __CUDA_RUNTIME_LIBRARY_ALLOWED)
|
||||
message(FATAL_ERROR "CUDA_RUNTIME_LIBRARY value '${__CUDA_RUNTIME_LIBRARY}' is not in allowed list of '${__CUDA_RUNTIME_LIBRARY_ALLOWED}'")
|
||||
endif()
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
else()
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
endif()
|
||||
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
@ -359,6 +352,8 @@ function(cutlass_add_executable NAME)
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY ${__CUDA_RUNTIME_LIBRARY})
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_target_sources NAME)
|
||||
@ -369,7 +364,6 @@ function(cutlass_target_sources NAME)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
target_sources(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
|
||||
endfunction()
|
||||
|
||||
@ -1,9 +1,20 @@
|
||||
# Publications Using Cutlass
|
||||
|
||||
## 2024
|
||||
|
||||
- ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024.
|
||||
|
||||
- ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024.
|
||||
|
||||
- ["EVT: Accelerating Deep Learning Training with Epilogue Visitor Tree"](https://dl.acm.org/doi/10.1145/3620666.3651369). Zhaodong Chen, Andrew Kerr, Richard Cai, Jack Kosaian, Haicheng Wu, Yufei Ding, and Yuan Xie. _Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, April 2024.
|
||||
|
||||
- ["Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level"](https://arxiv.org/abs/2403.04690). Ali Hassani, Wen-Mei Hwu, Humphrey Shi. _arXiv_, March 2024.
|
||||
|
||||
## 2023
|
||||
|
||||
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.
|
||||
|
||||
- ["Benchmarking GPU Tensor Cores on General Matrix Multiplication Kernels through CUTLASS"](https://www.mdpi.com/2076-3417/13/24/13022). Xuanteng Huang, Xianwei Zhang, Panfei Yang, Nong Xiao. _Journal of Applied Sciences_, December 2023.
|
||||
|
||||
- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.
|
||||
|
||||
@ -19,6 +30,8 @@
|
||||
|
||||
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023.
|
||||
|
||||
- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, Feburary 2023.
|
||||
|
||||
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
|
||||
|
||||
## 2022
|
||||
|
||||
109
README.md
109
README.md
@ -1,8 +1,8 @@
|
||||

|
||||

|
||||
|
||||
# CUTLASS 3.4
|
||||
# CUTLASS 3.6.0
|
||||
|
||||
_CUTLASS 3.4 - February 2024_
|
||||
_CUTLASS 3.6.0 - October 2024_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
|
||||
@ -19,16 +19,16 @@ mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for half-precision floating
|
||||
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
|
||||
single-precision floating point (FP32),
|
||||
[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
double-precision floating
|
||||
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
|
||||
CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta, Turing, Ampere, and Hopper architectures.
|
||||
|
||||
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
|
||||
See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.
|
||||
|
||||
See the [functionality listing](/media/docs/functionality.md) for the list of operations
|
||||
See the [functionality listing](./media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
@ -37,25 +37,31 @@ CuTe is a collection of C++ CUDA template abstractions for defining and operatin
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
|
||||
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).
|
||||
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](./media/docs/cute/00_quickstart.md).
|
||||
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
|
||||
# What's New in CUTLASS 3.4
|
||||
|
||||
CUTLASS 3.4.1 is an update to CUTLASS adding:
|
||||
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
|
||||
- Improvements for Hopper [Group-GEMM](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMM](/examples/56_hopper_ptr_array_batched_gemm).
|
||||
- Updates and bugfixes from the community (thanks!).
|
||||
# What's New in CUTLASS 3.6
|
||||
|
||||
CUTLASS 3.4.0 is an update to CUTLASS adding:
|
||||
CUTLASS 3.6.0 is an update to CUTLASS adding:
|
||||
|
||||
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
|
||||
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above)
|
||||
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
|
||||
- [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
|
||||
- Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
|
||||
- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
|
||||
+ [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
|
||||
+ [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
|
||||
+ [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
|
||||
- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md).
|
||||
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
|
||||
- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
|
||||
- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper).
|
||||
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h)
|
||||
- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu).
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
@ -74,16 +80,15 @@ Starting from CUTLASS 3.0, CUTLASS removed support for the following:
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit peak performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture),
|
||||
an [NVIDIA L40](https://www.nvidia.com/en-us/data-center/l40/) (NVIDIA Ada architecture),
|
||||
an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere architecture),
|
||||
and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture).
|
||||
CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
computations. The above figure shows the continual CUTLASS performance improvements
|
||||
on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since
|
||||
CUTLASS 3.1.
|
||||
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
Tensor Core operations are implemented using CUDA's
|
||||
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
|
||||
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
|
||||
@ -98,7 +103,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
performs best when built with the [**CUDA 12.3.2 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
performs best when built with the [**CUDA 12.4 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
|
||||
|
||||
## Operating Systems
|
||||
@ -136,34 +141,36 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|
||||
|
||||
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12 or 11.8, the kernel is expected to fail with a runtime error.
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CUDA Toolkit 12 or 11.8, the kernel is expected to fail with a runtime error.
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
```
|
||||
|
||||
Please refer to the [functionality documentation](media/docs/functionality.md) for details on which kernels require which target architectures.
|
||||
Please refer to the [functionality documentation](./media/docs/functionality.md) for details on which kernels require which target architectures.
|
||||
|
||||
# Documentation
|
||||
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
|
||||
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](media/docs/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](media/docs/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilate rapid development
|
||||
- [Quick Start Guide](./media/docs/quickstart.md) - build and run CUTLASS
|
||||
- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](./media/docs/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilate rapid development
|
||||
- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
|
||||
kernels in the same stream, and how it is used in CUTLASS.
|
||||
|
||||
# Resources
|
||||
We have also described the structure of an efficient GEMM in our talk at the
|
||||
@ -182,7 +189,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
|
||||
paths.
|
||||
|
||||
CUTLASS unit tests, examples, and utilities can be build with CMake.
|
||||
The minimum version of CMake is given in the [Quickstart guide](media/docs/quickstart.md).
|
||||
The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
@ -227,7 +234,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
|
||||
and template concepts defined in the CUTLASS project.
|
||||
|
||||
A detailed explanation of the source code organization may be found in the
|
||||
[CUTLASS documentation](media/docs/code_organization.md), but several main components are summarized below.
|
||||
[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below.
|
||||
|
||||
## CUTLASS Template Library
|
||||
|
||||
@ -276,7 +283,7 @@ include/ # client applications should target this directory
|
||||
|
||||
### CUTLASS SDK Examples
|
||||
|
||||
[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.
|
||||
[CUTLASS SDK examples](./examples) apply CUTLASS templates to implement basic computations.
|
||||
|
||||
### Tools
|
||||
|
||||
@ -301,7 +308,7 @@ tools/
|
||||
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](media/docs/quickstart.md).
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md).
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
@ -517,9 +524,9 @@ reference_device: Passed
|
||||
|
||||
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
|
||||
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
|
||||
- [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
|
||||
- [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md)
|
||||
|
||||
|
||||
# About
|
||||
|
||||
@ -50,5 +50,3 @@ if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
else()
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
|
||||
endif()
|
||||
|
||||
@_INLINE_PER_TEST_CODE@
|
||||
|
||||
@ -30,14 +30,14 @@ if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT)
|
||||
# The longform/extended format allows generator expressions to be
|
||||
# expanded property and is useful in contexts where the files need
|
||||
# to be immediately included into being-processed cmake code.
|
||||
add_test(NAME @TEST_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
add_test(NAME @TESTCASE_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
else()
|
||||
add_test(@TEST_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
add_test(@TESTCASE_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
endif()
|
||||
|
||||
if (TEST_EXE_WORKING_DIRECTORY)
|
||||
set_tests_properties(@TEST_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}")
|
||||
set_tests_properties(@TESTCASE_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}")
|
||||
endif()
|
||||
|
||||
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)
|
||||
set_tests_properties(@TESTCASE_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)
|
||||
|
||||
|
||||
@ -34,10 +34,11 @@ if(GOOGLETEST_DIR)
|
||||
set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
|
||||
endif()
|
||||
|
||||
set(GTEST_REPOSITORY "https://github.com/google/googletest.git" CACHE STRING "GoogleTest repo to fetch")
|
||||
FetchContent_Declare(
|
||||
googletest
|
||||
GIT_REPOSITORY https://github.com/google/googletest.git
|
||||
GIT_TAG v1.13.0
|
||||
GIT_REPOSITORY ${GTEST_REPOSITORY}
|
||||
GIT_TAG v1.14.0
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(googletest)
|
||||
|
||||
@ -260,7 +260,7 @@ private:
|
||||
if (options.vectorize <= 2) return std::make_pair(false, -1);
|
||||
|
||||
// Boundary check.
|
||||
if (i > elements.size() || (i + options.vectorize - 1) > elements.size())
|
||||
if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size()))
|
||||
return std::make_pair(false, -1);
|
||||
|
||||
// Check if either all elements are valid or invalid.
|
||||
|
||||
@ -94,7 +94,7 @@ __global__ void copy(
|
||||
|
||||
typename Iterator::Fragment fragment;
|
||||
|
||||
for(int i = 0; i < fragment.size(); ++i) {
|
||||
for(size_t i = 0; i < fragment.size(); ++i) {
|
||||
fragment[i] = 0;
|
||||
}
|
||||
|
||||
|
||||
@ -207,15 +207,15 @@ cudaError_t strided_batched_gemm_nn_reference(
|
||||
|
||||
cudaError_t result = cudaSuccess;
|
||||
|
||||
if (A.size() < lda * k * batch_count) {
|
||||
if (A.size() < size_t(lda * k * batch_count)) {
|
||||
std::cout << "the size of A is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
if (B.size() < ldb * n) {
|
||||
if (B.size() < size_t(ldb * n)) {
|
||||
std::cout << "the size of B is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
if (C.size() < ldc * n * batch_count) {
|
||||
if (C.size() < size_t(ldc * n * batch_count)) {
|
||||
std::cout << "the size of C is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
@ -162,7 +162,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes ?
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
||||
@ -161,7 +161,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
||||
@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
|
||||
//
|
||||
|
||||
@ -80,4 +80,3 @@ foreach(FUSION_GEMM_EXAMPLE
|
||||
add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE})
|
||||
|
||||
endforeach()
|
||||
|
||||
|
||||
@ -102,7 +102,7 @@ struct B2bFusedGroupedGemmRun
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
view, seed, 1, -1, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
@ -231,7 +231,7 @@ struct B2bFusedGroupedGemmRun
|
||||
host_tensor_ref_D1.at(i).sync_device();
|
||||
|
||||
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
|
||||
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());;
|
||||
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
|
||||
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
|
||||
@ -340,7 +340,7 @@ struct B2bFusedGroupedGemmRun
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
host_tensor_D1.at(i).sync_host();;
|
||||
host_tensor_D1.at(i).sync_host();
|
||||
|
||||
//
|
||||
// Verify
|
||||
|
||||
@ -157,35 +157,34 @@ struct B2bGemm {
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size_0;
|
||||
GemmCoord problem_size_1;
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0;
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
int64_t batch_stride_A0;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_D1;
|
||||
int64_t batch_stride_Bias0;
|
||||
int64_t batch_stride_Scale0;
|
||||
typename OutputOp0::Params epilogue0;
|
||||
typename OutputOp1::Params epilogue1;
|
||||
int batch_count;
|
||||
GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
|
||||
GemmCoord problem_size_0{0,0,0};
|
||||
GemmCoord problem_size_1{0,0,0};
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0{};
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
|
||||
int64_t batch_stride_A0{0};
|
||||
int64_t batch_stride_B0{0};
|
||||
int64_t batch_stride_B1{0};
|
||||
int64_t batch_stride_C1{0};
|
||||
int64_t batch_stride_D1{0};
|
||||
int64_t batch_stride_Bias0{0};
|
||||
int64_t batch_stride_Scale0{0};
|
||||
typename OutputOp0::Params epilogue0 {};
|
||||
typename OutputOp1::Params epilogue1 {};
|
||||
int batch_count{1};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() : mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}
|
||||
Arguments() = default;
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -285,47 +284,45 @@ struct B2bGemm {
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmUniversalMode mode;
|
||||
cutlass::gemm::GemmCoord problem_size_0;
|
||||
cutlass::gemm::GemmCoord problem_size_1;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
typename B2bMma::IteratorA0::Params params_A0;
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0;
|
||||
typename B2bMma::IteratorB0::Params params_B0;
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0;
|
||||
typename Epilogue::OutputTileIterator::Params params_C0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
|
||||
typename B2bMma::IteratorB1::Params params_B1;
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1;
|
||||
typename Epilogue::OutputTileIterator::Params params_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
|
||||
typename Epilogue::OutputTileIterator::Params params_D1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
int64_t batch_stride_A0;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_D1;
|
||||
int64_t batch_stride_Bias0;
|
||||
int64_t batch_stride_Scale0;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations_0;
|
||||
int gemm_k_size_0;
|
||||
int gemm_k_iterations_1;
|
||||
int gemm_k_size_1;
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
|
||||
cutlass::gemm::GemmCoord problem_size_0{};
|
||||
cutlass::gemm::GemmCoord problem_size_1{};
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape{};
|
||||
int swizzle_log_tile{0};
|
||||
typename B2bMma::IteratorA0::Params params_A0{};
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0{};
|
||||
typename B2bMma::IteratorB0::Params params_B0{};
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0{};
|
||||
typename Epilogue::OutputTileIterator::Params params_C0{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
|
||||
typename B2bMma::IteratorB1::Params params_B1{};
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1{};
|
||||
typename Epilogue::OutputTileIterator::Params params_C1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
|
||||
typename Epilogue::OutputTileIterator::Params params_D1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
|
||||
typename OutputOp0::Params output_op_0{};
|
||||
typename OutputOp1::Params output_op_1{};
|
||||
int64_t batch_stride_A0{0};
|
||||
int64_t batch_stride_B0{0};
|
||||
int64_t batch_stride_B1{0};
|
||||
int64_t batch_stride_C1{0};
|
||||
int64_t batch_stride_D1{0};
|
||||
int64_t batch_stride_Bias0{0};
|
||||
int64_t batch_stride_Scale0{0};
|
||||
int *semaphore = nullptr;
|
||||
int gemm_k_iterations_0{0};
|
||||
int gemm_k_size_0{0};
|
||||
int gemm_k_iterations_1{0};
|
||||
int gemm_k_size_1{0};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
||||
Params() = default;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
|
||||
@ -194,7 +194,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
||||
@ -33,6 +33,11 @@ cutlass_example_add_executable(
|
||||
ampere_sparse_tensorop_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
15_ampere_sparse_tensorop_gemm_universal
|
||||
ampere_sparse_tensorop_gemm_universal.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
15_ampere_sparse_tensorop_gemm_with_visitor
|
||||
ampere_sparse_tensorop_gemm_with_visitor.cu
|
||||
|
||||
@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
||||
@ -0,0 +1,329 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere
|
||||
architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4.
|
||||
|
||||
Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of
|
||||
meta data is different for every data types. CUTLASS templates can automatically infer it based on
|
||||
input A and B. Check code below.
|
||||
|
||||
Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers
|
||||
efficiently.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_sparse_universal.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/host_uncompress.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output matrices and computation between
|
||||
// elements in input matrices.
|
||||
using ElementAccumulator = int32_t; // <- data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||
using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = int32_t; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
|
||||
// memory access. For a byte, it's 16
|
||||
// elements. This becomes the vector width of
|
||||
// math instructions in the epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmSparseUniversal<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||
using ElementInputE = typename Gemm::ElementE;
|
||||
using LayoutInputE = cutlass::layout::RowMajor;
|
||||
using ReorderedLayoutInputE = typename Gemm::LayoutE;
|
||||
|
||||
// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h
|
||||
// 50% Sparsity on Ampere
|
||||
constexpr int kSparse = Gemm::kSparse;
|
||||
// How many elements of A are covered per ElementE
|
||||
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||
// The size of individual meta data
|
||||
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
|
||||
int run() {
|
||||
|
||||
const int length_m = 512;
|
||||
const int length_n = 512;
|
||||
const int length_k = 1024;
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
|
||||
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
|
||||
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(2),
|
||||
ElementInputA(-2),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(2),
|
||||
ElementInputB(-2),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(2),
|
||||
ElementOutput(-2),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||
tensor_e.host_view(),
|
||||
1,
|
||||
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
|
||||
// instructions.
|
||||
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
|
||||
{problem_size.m(), problem_size.n(),
|
||||
problem_size.k() / kSparse / kElementsPerElementE});
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_e_reordered.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 2;
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size, // <- problem size of matrix multiplication
|
||||
split_k_slices,// <- k-dimension split factor
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
tensor_a.device_data(), // <- reference to matrix A on device
|
||||
tensor_b.device_data(), // <- reference to matrix B on device
|
||||
tensor_c.device_data(), // <- reference to matrix C on device
|
||||
tensor_d.device_data(), // <- reference to matrix D on device
|
||||
tensor_e_reordered.device_data(), // <- reference to matrix E on device
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
tensor_a.layout().stride(0),
|
||||
tensor_b.layout().stride(0),
|
||||
tensor_c.layout().stride(0),
|
||||
tensor_d.layout().stride(0),
|
||||
tensor_e_reordered.layout().stride(0)
|
||||
};
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// uncompress tensor_a based on meta data tensor_e. We need it for reference computing.
|
||||
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
|
||||
tensor_e.host_ref(), problem_size.m(), problem_size.k());
|
||||
|
||||
// Create instantiation for host reference gemm kernel
|
||||
cutlass::reference::host::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
typename Gemm::Operator>
|
||||
gemm_host;
|
||||
|
||||
// Launch host reference gemm kernel
|
||||
gemm_host(problem_size,
|
||||
alpha,
|
||||
tensor_a_uncompressed.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
beta,
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref());
|
||||
|
||||
// Copy output data from CUTLASS host for comparison
|
||||
tensor_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major * 10 + props.minor < 80) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
return run();
|
||||
}
|
||||
@ -94,7 +94,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 1
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
using Operator = cutlass::arch::OpMultiplyAddSaturate;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
@ -265,6 +265,10 @@ constexpr int NumStages = 3;
|
||||
// Which iterator algorithm to use: Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
||||
|
||||
// Is the output packed or strided
|
||||
// Use kStride if using strided output
|
||||
static cutlass::conv::StrideSupport const OutputStride = cutlass::conv::StrideSupport::kUnity;
|
||||
|
||||
// The epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
@ -289,7 +293,8 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm
|
||||
IteratorAlgorithm,
|
||||
OutputStride
|
||||
>::Kernel;
|
||||
|
||||
// Type of the actual kernel
|
||||
|
||||
@ -27,10 +27,14 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(TEST_STANDARD --m=1024 --n=1024 --k=1024)
|
||||
set(TEST_LARGE_PERFCHECK --m=4096 --n=3456 --k=4096 --perf-check)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
23_ampere_gemm_operand_reduction_fusion
|
||||
ampere_gemm_operand_reduction_fusion.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_STANDARD
|
||||
TEST_LARGE_PERFCHECK
|
||||
)
|
||||
|
||||
|
||||
@ -138,7 +138,7 @@ using Gemm = typename cutlass::gemm::device::GemmWithKReduction<
|
||||
>;
|
||||
|
||||
// Below is the reduction kernel used in the case of parallel split-k
|
||||
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
|
||||
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;
|
||||
|
||||
using ReduceOp = cutlass::reduction::thread::ReduceAdd<
|
||||
ElementAccumulator,
|
||||
@ -154,7 +154,7 @@ using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<
|
||||
|
||||
using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK<ReduceGemmSplitKKernel>;
|
||||
|
||||
using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;;
|
||||
using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
@ -377,22 +377,22 @@ Result profile(Options const &options) {
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1997,
|
||||
ElementInputA(2),
|
||||
ElementInputA(-2),
|
||||
ElementInputA(1),
|
||||
ElementInputA(-1),
|
||||
0); // <- Fill tensor A on host with uniform-distribution random data
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
2003,
|
||||
ElementInputB(2),
|
||||
ElementInputB(-2),
|
||||
ElementInputB(1),
|
||||
ElementInputB(-1),
|
||||
0); // <- Fill tensor B on host with uniform-distribution random data
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
2017,
|
||||
ElementOutput(2),
|
||||
ElementOutput(-2),
|
||||
ElementOutput(1),
|
||||
ElementOutput(-1),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
|
||||
@ -789,7 +789,7 @@ public:
|
||||
problem_count_check += bin.second.size();
|
||||
}
|
||||
|
||||
if (problem_count_check != this->problem_count()) {
|
||||
if (problem_count_check != size_t(this->problem_count())) {
|
||||
std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ implicitly to tf32 inside the GEMM kernel which means no change is needed to acc
|
||||
fp32 data by using NVIDIA Ampere architecture.
|
||||
|
||||
We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated
|
||||
using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h).
|
||||
using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h).
|
||||
|
||||
The trick is very simple
|
||||
a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big
|
||||
@ -45,11 +45,11 @@ The trick is very simple
|
||||
|
||||
a_small x b_small is discarded because they are too small.
|
||||
|
||||
This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32
|
||||
This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32
|
||||
results (SGEMM using SIMT) and against FP64 results (DGEMM)
|
||||
|
||||
To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to
|
||||
OpMultiplyAddFastF32.
|
||||
To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to
|
||||
OpMultiplyAddFastF32.
|
||||
|
||||
Now, we have several different flavors of sgemm now in the profiler for Ampere. Here are the difference
|
||||
|
||||
@ -97,14 +97,14 @@ struct Result {
|
||||
double l2_norm_fp32_vs_fp64;
|
||||
|
||||
// ctor
|
||||
Result(
|
||||
Result(
|
||||
int m, int n, int k,
|
||||
double runtime_ms, double gflops,
|
||||
double l2_norm_3xtf32_vs_fp64,
|
||||
double l2_norm_1xtf32_vs_fp64,
|
||||
double l2_norm_fp32_vs_fp64) :
|
||||
double l2_norm_fp32_vs_fp64) :
|
||||
m(m), n(n), k(k),
|
||||
runtime_ms(runtime_ms), gflops(gflops),
|
||||
runtime_ms(runtime_ms), gflops(gflops),
|
||||
l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64),
|
||||
l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64),
|
||||
l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {}
|
||||
@ -147,7 +147,7 @@ struct Options {
|
||||
int iterations;
|
||||
int seed;
|
||||
bool benchmark;
|
||||
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size({3456, 4096, 4096}),
|
||||
@ -190,7 +190,7 @@ struct Options {
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("seed", seed);
|
||||
cmd.get_cmd_line_argument("rand_mode", rand_mode);
|
||||
@ -227,9 +227,9 @@ struct Options {
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of real-valued multiply-adds
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fmas = problem_size.product();
|
||||
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
@ -258,7 +258,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
@ -272,10 +272,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
// Alignment
|
||||
// Alignment
|
||||
constexpr int Alignment = 4;
|
||||
|
||||
//
|
||||
//
|
||||
// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64)
|
||||
//
|
||||
|
||||
@ -296,7 +296,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::Gemm<
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
Alignment,
|
||||
Alignment,
|
||||
Alignment,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddFastF32>;
|
||||
@ -318,7 +318,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::Gemm<
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
Alignment,
|
||||
Alignment,
|
||||
Alignment,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAdd>;
|
||||
@ -356,7 +356,7 @@ bool run(Options &options) {
|
||||
cutlass::HostTensor<float, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<float, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
|
||||
if (options.rand_mode == "uniform") {
|
||||
const float min = -1;
|
||||
@ -397,7 +397,7 @@ bool run(Options &options) {
|
||||
}
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d_F32.host_view()); // <- fill matrix D on host with zeros
|
||||
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F32.sync_device();
|
||||
tensor_b_F32.sync_device();
|
||||
@ -411,7 +411,7 @@ bool run(Options &options) {
|
||||
cutlass::HostTensor<double, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<double, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
|
||||
|
||||
// Gemm output (D) for GEMM_F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
// Gemm output (D) for GEMM_3xTF32
|
||||
@ -426,7 +426,7 @@ bool run(Options &options) {
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F64.sync_device();
|
||||
tensor_b_F64.sync_device();
|
||||
@ -464,7 +464,7 @@ bool run(Options &options) {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_3xTF32 gemm_op_3xTF32;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32);
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
|
||||
@ -568,7 +568,7 @@ bool run(Options &options) {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_1xTF32 gemm_op_1xtf32;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32);
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
@ -627,7 +627,7 @@ bool run(Options &options) {
|
||||
tensor_d_F32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/////// Compute l2 norms
|
||||
/////// Compute l2 norms
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// l2 norm 3xTF32 vs F64
|
||||
@ -664,7 +664,7 @@ bool run(Options &options) {
|
||||
std::cout << "GFLOPs: " << result.gflops << std::endl;
|
||||
std::cout << "Normalized L2 norm of" << std::endl;
|
||||
std::cout.precision(8);
|
||||
std::cout << std::scientific
|
||||
std::cout << std::scientific
|
||||
<< " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl
|
||||
<< " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl
|
||||
<< " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl;
|
||||
@ -673,11 +673,11 @@ bool run(Options &options) {
|
||||
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
@ -690,7 +690,7 @@ int main(int argc, const char **argv) {
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
@ -716,17 +716,17 @@ int main(int argc, const char **argv) {
|
||||
|
||||
if (options.benchmark) {
|
||||
for (int k = 4; k <= 65536; k *= 2) {
|
||||
|
||||
|
||||
options.problem_size[2] = k;
|
||||
|
||||
|
||||
printf("Gemm problem size: %d x %d x %d\n", \
|
||||
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
|
||||
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
result &= run(options);
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -34,7 +34,7 @@
|
||||
difference is that this example uses 3xtf32 on complex gemm.
|
||||
|
||||
To enable this feature, the only change needs to make is to change OpMultiplyAddComplex
|
||||
to OpMultiplyAddComplexFastF32.
|
||||
to OpMultiplyAddComplexFastF32.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
@ -74,14 +74,14 @@ struct Result {
|
||||
double l2_norm_fp32_vs_fp64;
|
||||
|
||||
// ctor
|
||||
Result(
|
||||
Result(
|
||||
int m, int n, int k,
|
||||
double runtime_ms, double gflops,
|
||||
double l2_norm_3xtf32_vs_fp64,
|
||||
double l2_norm_1xtf32_vs_fp64,
|
||||
double l2_norm_fp32_vs_fp64) :
|
||||
double l2_norm_fp32_vs_fp64) :
|
||||
m(m), n(n), k(k),
|
||||
runtime_ms(runtime_ms), gflops(gflops),
|
||||
runtime_ms(runtime_ms), gflops(gflops),
|
||||
l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64),
|
||||
l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64),
|
||||
l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {}
|
||||
@ -124,7 +124,7 @@ struct Options {
|
||||
int iterations;
|
||||
int seed;
|
||||
bool benchmark;
|
||||
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size({3456, 4096, 4096}),
|
||||
@ -153,7 +153,7 @@ struct Options {
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("seed", seed);
|
||||
cmd.get_cmd_line_argument("rand_mode", rand_mode);
|
||||
@ -190,9 +190,9 @@ struct Options {
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of real-valued multiply-adds
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fmas = problem_size.product();
|
||||
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
@ -221,7 +221,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
@ -239,7 +239,7 @@ constexpr int NumStages = 3;
|
||||
constexpr cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone;
|
||||
constexpr cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone;
|
||||
|
||||
//
|
||||
//
|
||||
// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64)
|
||||
//
|
||||
|
||||
@ -260,7 +260,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex<
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
TransformA,
|
||||
TransformA,
|
||||
TransformB,
|
||||
cutlass::arch::OpMultiplyAddComplexFastF32>;
|
||||
|
||||
@ -281,7 +281,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex<
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
TransformA,
|
||||
TransformA,
|
||||
TransformB,
|
||||
cutlass::arch::OpMultiplyAddComplex>;
|
||||
|
||||
@ -296,7 +296,7 @@ bool run(Options &options) {
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
|
||||
if (options.rand_mode == "uniform") {
|
||||
const float min = -1;
|
||||
@ -337,7 +337,7 @@ bool run(Options &options) {
|
||||
}
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d_F32.host_view()); // <- fill matrix D on host with zeros
|
||||
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F32.sync_device();
|
||||
tensor_b_F32.sync_device();
|
||||
@ -351,7 +351,7 @@ bool run(Options &options) {
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
|
||||
|
||||
// Gemm output (D) for GEMM_F64
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
// Gemm output (D) for GEMM_3xTF32
|
||||
@ -366,7 +366,7 @@ bool run(Options &options) {
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F64.sync_device();
|
||||
tensor_b_F64.sync_device();
|
||||
@ -404,7 +404,7 @@ bool run(Options &options) {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_3xTF32 gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_3xtf32 = gemm_op.can_implement(arguments_3xtf32);
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
|
||||
@ -508,7 +508,7 @@ bool run(Options &options) {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_1xTF32 gemm_op_1xtf32;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32);
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
@ -569,7 +569,7 @@ bool run(Options &options) {
|
||||
tensor_d_F32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/////// Compute l2 norms
|
||||
/////// Compute l2 norms
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// l2 norm 3xTF32 vs F64
|
||||
@ -606,7 +606,7 @@ bool run(Options &options) {
|
||||
std::cout << "GFLOPs: " << result.gflops << std::endl;
|
||||
std::cout << "Normalized L2 norm of" << std::endl;
|
||||
std::cout.precision(8);
|
||||
std::cout << std::scientific
|
||||
std::cout << std::scientific
|
||||
<< " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl
|
||||
<< " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl
|
||||
<< " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl;
|
||||
@ -615,11 +615,11 @@ bool run(Options &options) {
|
||||
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
@ -632,7 +632,7 @@ int main(int argc, const char **argv) {
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
@ -658,17 +658,17 @@ int main(int argc, const char **argv) {
|
||||
|
||||
if (options.benchmark) {
|
||||
for (int k = 4; k <= 65536; k *= 2) {
|
||||
|
||||
|
||||
options.problem_size[2] = k;
|
||||
|
||||
|
||||
printf("Gemm problem size: %d x %d x %d\n", \
|
||||
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
|
||||
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
result &= run(options);
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -113,10 +113,10 @@ cudaError_t CutlassSsyrkNN(
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
|
||||
5, // Stages
|
||||
1, // AligmentA
|
||||
1, // AlignmentA
|
||||
false, // SplitKSerail
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
cutlass::BlasMode::kSymmetric
|
||||
>;
|
||||
|
||||
@ -149,7 +149,7 @@ cudaError_t CutlassSsyrkNN(
|
||||
//
|
||||
// Launch the CUTLASS SYRK kernel.
|
||||
//
|
||||
|
||||
|
||||
cutlass::Status status = syrk_operator(args);
|
||||
|
||||
//
|
||||
|
||||
@ -36,7 +36,7 @@ implicitly to tf32 inside the SYMM kernel which means no change is needed to acc
|
||||
F32 data by using NVIDIA Ampere architecture.
|
||||
|
||||
We can use the tf32 mode of tensor core to emulate a fast accurate SYMM kernel which is accelerated
|
||||
using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h).
|
||||
using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h).
|
||||
|
||||
The trick is very simple
|
||||
a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big
|
||||
@ -45,11 +45,11 @@ The trick is very simple
|
||||
|
||||
a_small x b_small is discarded because they are too small.
|
||||
|
||||
This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32
|
||||
This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32
|
||||
results (SSYMM from cuBLAS) and against F64 results (DSYMM from CUTLASS)
|
||||
|
||||
To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to
|
||||
OpMultiplyAddFastF32.
|
||||
To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to
|
||||
OpMultiplyAddFastF32.
|
||||
|
||||
Now, we have two different flavors of SSYMM in the profiler for Ampere:
|
||||
|
||||
@ -95,7 +95,7 @@ struct Options {
|
||||
float beta;
|
||||
std::string rand_mode;
|
||||
int seed;
|
||||
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size({4096, 4096, 4096}),
|
||||
@ -137,7 +137,7 @@ struct Options {
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
|
||||
cmd.get_cmd_line_argument("seed", seed);
|
||||
cmd.get_cmd_line_argument("rand_mode", rand_mode);
|
||||
|
||||
@ -193,7 +193,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
@ -207,10 +207,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
// Alignment
|
||||
// Alignment
|
||||
constexpr int Alignment = 4;
|
||||
|
||||
//
|
||||
//
|
||||
// CUTLASS Symm Operators (SSYM: Symm_3xTF32, Symm_1xTF32, DSYMM: Symm_F64)
|
||||
//
|
||||
|
||||
@ -233,7 +233,7 @@ using Symm_3xTF32 = cutlass::gemm::device::Symm<
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
1, // Symmetric matrix is always align 1
|
||||
1, // Symmetric matrix is always align 1
|
||||
Alignment,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddFastF32>;
|
||||
@ -257,7 +257,7 @@ using Symm_1xTF32 = cutlass::gemm::device::Symm<
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
1, // Symmetric matrix is always align 1
|
||||
1, // Symmetric matrix is always align 1
|
||||
Alignment,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAdd>;
|
||||
@ -298,7 +298,7 @@ bool run(Options &options) {
|
||||
cutlass::HostTensor<float, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<float, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
|
||||
if (options.rand_mode == "uniform") {
|
||||
const float min = -1;
|
||||
@ -339,7 +339,7 @@ bool run(Options &options) {
|
||||
}
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d_F32.host_view()); // <- fill matrix D on host with zeros
|
||||
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F32.sync_device();
|
||||
tensor_b_F32.sync_device();
|
||||
@ -353,7 +353,7 @@ bool run(Options &options) {
|
||||
cutlass::HostTensor<double, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<double, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
|
||||
|
||||
// Symm output (D) for SYMM_3xTF32
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
// Symm output (D) for SYMM_1xTF32
|
||||
@ -375,7 +375,7 @@ bool run(Options &options) {
|
||||
#if CUTLASS_ENABLE_CUBLAS
|
||||
cutlass::reference::host::TensorCopy(tensor_d_cublasF32.host_view(), tensor_d_F32.host_view());
|
||||
#endif
|
||||
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F64.sync_device();
|
||||
tensor_b_F64.sync_device();
|
||||
@ -430,7 +430,7 @@ bool run(Options &options) {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Symm_3xTF32 symm_op_3xtf32;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_3xtf32 = symm_op_3xtf32.can_implement(arguments_3xtf32);
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
|
||||
@ -477,7 +477,7 @@ bool run(Options &options) {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Symm_1xTF32 symm_op_1xtf32;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_1xtf32 = symm_op_1xtf32.can_implement(arguments_1xtf32);
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
@ -524,7 +524,7 @@ bool run(Options &options) {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Symm_F64 symm_op_f64;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_f64 = symm_op_f64.can_implement(arguments_f64);
|
||||
CUTLASS_CHECK(status_f64);
|
||||
|
||||
@ -568,7 +568,7 @@ bool run(Options &options) {
|
||||
static_cast<const float*>(&beta),
|
||||
static_cast<float*>(tensor_d_cublasF32.device_data()),
|
||||
int(tensor_d_cublasF32.layout().stride(0))
|
||||
);
|
||||
);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
@ -576,7 +576,7 @@ bool run(Options &options) {
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 7. Compute l2 norms
|
||||
/// 7. Compute l2 norms
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUTLASS_ENABLE_CUBLAS
|
||||
@ -605,20 +605,20 @@ bool run(Options &options) {
|
||||
double l2_norm_3xtf32_vs_cublasf32 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_3xTF32.host_view(), tensor_d_cublasF32.host_view());
|
||||
#endif
|
||||
|
||||
|
||||
// l2 norm 3xTF32 vs 1xTF32
|
||||
double l2_norm_3xtf32_vs_1xtf32 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_3xTF32.host_view(), tensor_d_1xTF32.host_view());
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Print kernel info and L2 norms
|
||||
// Print kernel info and L2 norms
|
||||
std::cout << "Problem Size: (" << problem_size.m() << "," << problem_size.n() << "," << problem_size.k() << ") "
|
||||
<< "Alpha: " << alpha << "," << " Beta: " << beta << std::endl;
|
||||
std::cout << std::fixed;
|
||||
std::cout << "Normalized L2 norm of" << std::endl;
|
||||
std::cout.precision(8);
|
||||
std::cout << std::scientific
|
||||
std::cout << std::scientific
|
||||
#if CUTLASS_ENABLE_CUBLAS
|
||||
<< " - cuBLAS F32 error with F64 reference : " << l2_norm_cublasf32_vs_f64 << std::endl
|
||||
#endif
|
||||
@ -633,11 +633,11 @@ bool run(Options &options) {
|
||||
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
@ -650,7 +650,7 @@ int main(int argc, const char **argv) {
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
|
||||
@ -42,7 +42,8 @@
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/numeric_size.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
@ -56,6 +57,7 @@
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
@ -456,7 +458,7 @@ struct Testbed {
|
||||
bool verify_tensor(std::vector<Element> vector_Input, \
|
||||
std::vector<Element> vector_Input_Ref) {
|
||||
|
||||
int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size();
|
||||
auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size());
|
||||
float abs_tol = options.tolerance;
|
||||
float rel_tol = options.tolerance;
|
||||
|
||||
@ -657,7 +659,9 @@ struct Testbed {
|
||||
}
|
||||
|
||||
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
|
||||
int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n();
|
||||
int64_t bytes = cutlass::bits_to_bytes(
|
||||
(cutlass::sizeof_bits<ElementD>::value * 2 + cutlass::sizeof_bits<ElementSoftmax>::value) *
|
||||
options.problem_size.m() * options.problem_size.n());
|
||||
|
||||
double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9);
|
||||
double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30);
|
||||
|
||||
@ -59,11 +59,11 @@
|
||||
// Also, we don't check the index value is legal and index array point is valid
|
||||
// for the sake of the performance.
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
#include <cstdlib>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <algorithm>
|
||||
@ -215,7 +215,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8
|
||||
// 16, 8, 16 -> Ampere
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Define the epilogue operation as LinearCombination. This is approximately equal to
|
||||
//
|
||||
|
||||
@ -454,48 +454,48 @@ struct Testbed {
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_A0.host_view(),
|
||||
options.seed,
|
||||
ElementInputA0(5),
|
||||
ElementInputA0(-5),
|
||||
ElementInputA0(4),
|
||||
ElementInputA0(-4),
|
||||
0
|
||||
);
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_B0.host_view(),
|
||||
options.seed + 1,
|
||||
ElementInputB0(5),
|
||||
ElementInputB0(-5),
|
||||
ElementInputB0(4),
|
||||
ElementInputB0(-4),
|
||||
0
|
||||
);
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_A1.host_view(),
|
||||
options.seed + 2,
|
||||
ElementInputA1(5),
|
||||
ElementInputA1(-5),
|
||||
ElementInputA1(4),
|
||||
ElementInputA1(-4),
|
||||
0
|
||||
);
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_Beta.host_view(),
|
||||
options.seed + 3,
|
||||
ElementInputScaleBias(5),
|
||||
ElementInputScaleBias(-5),
|
||||
ElementInputScaleBias(4),
|
||||
ElementInputScaleBias(-4),
|
||||
0
|
||||
);
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_Gamma.host_view(),
|
||||
options.seed + 4,
|
||||
ElementInputScaleBias(5),
|
||||
ElementInputScaleBias(-5),
|
||||
ElementInputScaleBias(4),
|
||||
ElementInputScaleBias(-4),
|
||||
0
|
||||
);
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_Shifted_K.host_view(),
|
||||
options.seed + 5,
|
||||
ElementOutput(5),
|
||||
ElementOutput(-6),
|
||||
ElementOutput(4),
|
||||
ElementOutput(-5),
|
||||
0
|
||||
);
|
||||
|
||||
|
||||
@ -803,7 +803,7 @@ public:
|
||||
// Use 'D' for the in/out workspace
|
||||
this->block_D.copy_from_device(this->block_C.get());
|
||||
|
||||
for (int i = 0; i < this->options.problem_sizes.size(); ++i) {
|
||||
for (size_t i = 0; i < this->options.problem_sizes.size(); ++i) {
|
||||
cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i];
|
||||
int32_t batch_count = 1;
|
||||
int64_t lda = this->lda_host.at(i);
|
||||
@ -904,10 +904,10 @@ public:
|
||||
// Run profiling loop
|
||||
//
|
||||
|
||||
int last_stream_idx = 0;
|
||||
size_t last_stream_idx = 0;
|
||||
|
||||
for (int iter = 0; iter < this->options.iterations; ++iter) {
|
||||
for (int i = 0; i < this->options.problem_sizes.size(); ++i) {
|
||||
for (size_t i = 0; i < this->options.problem_sizes.size(); ++i) {
|
||||
cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i];
|
||||
int32_t batch_count = 1;
|
||||
int64_t lda = this->lda_host.at(i);
|
||||
@ -1146,7 +1146,7 @@ public:
|
||||
);
|
||||
|
||||
// Initialize the Rank2K object
|
||||
Rank2K rank2k;
|
||||
Rank2K rank2k{};
|
||||
size_t workspace_size = rank2k.get_workspace_size(args);
|
||||
cutlass::DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
@ -33,11 +33,7 @@
|
||||
computing reference permutations of 4/5D tensors when source data is column-major.
|
||||
*/
|
||||
#pragma once
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include "assert.h"
|
||||
#endif
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
@ -118,7 +118,7 @@ operation = Conv2dOperation(
|
||||
conv_kind=cutlass_bindings.conv.Operator.fprop,
|
||||
iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized,
|
||||
arch=cc, tile_description=tile_description,
|
||||
A=A, B=B, C=C, stride_support=StrideSupport.Strided,
|
||||
A=A, B=B, C=C, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor
|
||||
)
|
||||
|
||||
|
||||
@ -30,8 +30,8 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
#include <float.h>
|
||||
#include <stdio.h>
|
||||
#include <cfloat>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -40,7 +40,7 @@
|
||||
// Nans & inf detection
|
||||
#define NANCHECK(frag) \
|
||||
{ \
|
||||
for (int _i = 0; _i < frag.size(); ++_i) { \
|
||||
for (size_t _i = 0; _i < frag.size(); ++_i) { \
|
||||
assert(std::isfinite(float(frag[_i]))); \
|
||||
assert(!std::isnan(float(frag[_i]))); \
|
||||
} \
|
||||
@ -147,7 +147,7 @@ constexpr __string_view __get_type_name() {
|
||||
{ \
|
||||
auto typeStr = __get_type_name<decltype(frag)>(); \
|
||||
PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \
|
||||
for (int _start = 0; _start < frag.size(); _start += 8) { \
|
||||
for (size_t _start = 0; _start < frag.size(); _start += 8) { \
|
||||
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
|
||||
} \
|
||||
/*__syncthreads(); \
|
||||
|
||||
@ -43,11 +43,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
@ -57,12 +53,9 @@
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@ -43,11 +43,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
@ -57,16 +53,12 @@
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/scale_type.h"
|
||||
|
||||
@ -167,58 +167,39 @@ public:
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord *problem_sizes0;
|
||||
GemmCoord *problem_sizes1;
|
||||
GemmCoord *problem_sizes0{nullptr};
|
||||
GemmCoord *problem_sizes1{nullptr};
|
||||
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
int problem_count{0};
|
||||
int threadblock_count{0};
|
||||
|
||||
ElementQ ** ptr_Q;
|
||||
ElementK ** ptr_K;
|
||||
ElementP ** ptr_P;
|
||||
ElementV ** ptr_V;
|
||||
ElementO ** ptr_O;
|
||||
ElementOAccum ** ptr_O_accum;
|
||||
ElementQ ** ptr_Q{nullptr};
|
||||
ElementK ** ptr_K{nullptr};
|
||||
ElementP ** ptr_P{nullptr};
|
||||
ElementV ** ptr_V{nullptr};
|
||||
ElementO ** ptr_O{nullptr};
|
||||
ElementOAccum ** ptr_O_accum{nullptr};
|
||||
|
||||
typename LayoutQ::Stride::LongIndex *ldq;
|
||||
typename LayoutK::Stride::LongIndex *ldk;
|
||||
typename LayoutP::Stride::LongIndex *ldv;
|
||||
typename LayoutO::Stride::LongIndex *ldo;
|
||||
|
||||
// Scale
|
||||
ElementAccumulator scale;
|
||||
typename LayoutQ::Stride::LongIndex *ldq{nullptr};
|
||||
typename LayoutK::Stride::LongIndex *ldk{nullptr};
|
||||
typename LayoutP::Stride::LongIndex *ldv{nullptr};
|
||||
typename LayoutO::Stride::LongIndex *ldo{nullptr};
|
||||
|
||||
// Whether causal masking is to be performed
|
||||
bool causal;
|
||||
bool causal{false};
|
||||
|
||||
// Scale
|
||||
ElementAccumulator scale{0};
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord *host_problem_sizes;
|
||||
GemmCoord *host_problem_sizes{nullptr};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
problem_count(0),
|
||||
threadblock_count(0),
|
||||
ptr_Q(nullptr),
|
||||
ptr_K(nullptr),
|
||||
ptr_P(nullptr),
|
||||
ptr_V(nullptr),
|
||||
ptr_O(nullptr),
|
||||
ptr_O_accum(nullptr),
|
||||
ldq(nullptr),
|
||||
ldk(nullptr),
|
||||
ldv(nullptr),
|
||||
ldo(nullptr),
|
||||
scale(0),
|
||||
causal(false),
|
||||
host_problem_sizes(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
/// Default ctor
|
||||
Arguments() = default;
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -569,7 +550,7 @@ public:
|
||||
|
||||
auto prologueV = [&](int blockN) {
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
|
||||
typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
|
||||
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
@ -738,7 +719,7 @@ public:
|
||||
}
|
||||
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
|
||||
typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])},
|
||||
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
@ -780,15 +761,15 @@ public:
|
||||
using EpilogueOutputOp = typename cutlass::epilogue::
|
||||
thread::MemoryEfficientAttentionNormalize<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
output_t,
|
||||
output_accum_t>::type,
|
||||
output_accum_t,
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator,
|
||||
output_accum_t,
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
kIsFirst::value,
|
||||
kIsLast::value,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
EpiloguePipelined<
|
||||
@ -796,7 +777,7 @@ public:
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
typename MM1::OutputTileIterator,
|
||||
typename MM1::OutputTileIteratorAccum>::type,
|
||||
typename DefaultEpilogue::
|
||||
@ -814,7 +795,7 @@ public:
|
||||
int col = blockN * MM1::Mma::Shape::kN;
|
||||
auto source_iter = createOutputAccumIter(col);
|
||||
auto dest_iter = gemm_kernel_utils::call_conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
@ -836,8 +817,8 @@ public:
|
||||
}
|
||||
|
||||
if (kKeepOutputInRF) {
|
||||
const bool kIsFirst = true;
|
||||
const bool kIsLast = true;
|
||||
constexpr bool kIsFirst = true;
|
||||
constexpr bool kIsLast = true;
|
||||
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
||||
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
|
||||
using ElementCompute = typename DefaultOp::ElementCompute;
|
||||
|
||||
@ -286,7 +286,7 @@ struct Options {
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fops = int64_t();
|
||||
|
||||
for (int i = 0; i < problem_sizes0.size(); ++i) {
|
||||
for (size_t i = 0; i < problem_sizes0.size(); ++i) {
|
||||
auto const& problem0 = problem_sizes0[i];
|
||||
auto const& problem1 = problem_sizes1[i];
|
||||
for (int row = 0; row < problem0.m(); ++row) {
|
||||
|
||||
@ -340,7 +340,7 @@ struct Options {
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fops = int64_t();
|
||||
|
||||
for (int i = 0; i < problem_sizes0.size(); ++i) {
|
||||
for (size_t i = 0; i < problem_sizes0.size(); ++i) {
|
||||
auto const& problem0 = problem_sizes0[i];
|
||||
auto const& problem1 = problem_sizes1[i];
|
||||
|
||||
|
||||
@ -244,11 +244,13 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
|
||||
CUTLASS_DEVICE
|
||||
bool set_prologue_done(bool value) {
|
||||
prologue_done_ = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool set_zero_outside_bounds(bool value) {
|
||||
zero_outside_bounds_ = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <bool kLoadA = true, bool kLoadB = true>
|
||||
|
||||
@ -1799,7 +1799,7 @@ struct B2bGemm<
|
||||
if (rowIdx == 1) {
|
||||
lse_prefetched[colIdx] = accum_n < lse_extent
|
||||
? lse[accum_n]
|
||||
: platform::numeric_limits<accum_t>::infinity();
|
||||
: cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
|
||||
++colIdx;
|
||||
@ -1938,7 +1938,7 @@ struct B2bGemm<
|
||||
if (rowIdx == 1) {
|
||||
lse_prefetched[colIdx] = accum_n < lse_extent
|
||||
? lse[accum_n]
|
||||
: platform::numeric_limits<accum_t>::infinity();
|
||||
: cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
|
||||
++colIdx;
|
||||
|
||||
@ -55,13 +55,14 @@
|
||||
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
|
||||
{ \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
using BOOL_NAME = std::true_type; \
|
||||
F(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
using BOOL_NAME = std::false_type; \
|
||||
F(); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_ARCHTAG(CC, func) \
|
||||
{ \
|
||||
if (CC >= 80) { \
|
||||
|
||||
@ -32,6 +32,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cinttypes>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
@ -85,8 +86,6 @@
|
||||
#include "gemm/mma_from_smem.h"
|
||||
#include "transform/tile_smem_loader.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
using namespace gemm_kernel_utils;
|
||||
|
||||
namespace {
|
||||
@ -1446,7 +1445,7 @@ struct AttentionBackwardKernel {
|
||||
uint8_t lane_id) {
|
||||
cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements>
|
||||
dropout_keep_mask_doivj;
|
||||
dropout_keep_mask_doivj.fill(1);
|
||||
dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1});
|
||||
const float dropout_scale =
|
||||
kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f;
|
||||
|
||||
@ -1744,7 +1743,7 @@ struct AttentionBackwardKernel {
|
||||
[&](int accum_m) {},
|
||||
[&](int accum_m /*q*/, int accum_n /*k*/, int idx) {
|
||||
if (zij.at({accum_n, accum_m}) == scalar_t(0)) {
|
||||
dropout_keep_mask_doivj[idx] = cutlass::uint1b_t(0);
|
||||
dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0};
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
@ -1956,7 +1955,8 @@ struct AttentionBackwardKernel {
|
||||
|
||||
// no-op epilogue operator - just casting and storing contents of
|
||||
// accum to global memory
|
||||
typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1});
|
||||
typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op(
|
||||
typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1});
|
||||
typename MatmulDOIVJ::BiasGradEpilogue epilogue(
|
||||
shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id);
|
||||
epilogue(output_op, output_iter, accum, output_iter);
|
||||
@ -2211,7 +2211,7 @@ struct AttentionBackwardKernel {
|
||||
incrIteration(p, query_start, key_start, next_query, next_key);
|
||||
DISPATCH_BOOL(
|
||||
next_key != key_start, kForceReloadK, ([&]() {
|
||||
prologueQkNextIteration<kForceReloadK>(
|
||||
prologueQkNextIteration<kForceReloadK::value>(
|
||||
shared_storage, p, next_query, next_key, warp_id, lane_id);
|
||||
}));
|
||||
}
|
||||
@ -2342,7 +2342,7 @@ struct AttentionBackwardKernel {
|
||||
thread_id,
|
||||
cutlass::MatrixCoord{0, 0});
|
||||
|
||||
MatmulQK::Mma::prologue<kReloadK, true>(
|
||||
MatmulQK::Mma::template prologue<kReloadK, true>(
|
||||
shared_storage.mm_qk_k(),
|
||||
shared_storage.mm_qk_q(),
|
||||
iterator_A,
|
||||
@ -2369,6 +2369,7 @@ struct AttentionBackwardKernel {
|
||||
p.grad_value_ptr + key_start * p.gV_strideM(),
|
||||
{num_keys_in_block, p.head_dim_value},
|
||||
thread_id);
|
||||
|
||||
accumulateInGmem<MatmulGradV>(
|
||||
shared_storage.gradV_epilogue_final(),
|
||||
output_frags.gradV,
|
||||
@ -2406,7 +2407,7 @@ struct AttentionBackwardKernel {
|
||||
int thread_id = 32 * warp_id + lane_id;
|
||||
DISPATCH_BOOL(
|
||||
first, kIsFirst, ([&]() {
|
||||
static constexpr auto ScaleType = kIsFirst
|
||||
static constexpr auto ScaleType = kIsFirst::value
|
||||
? cutlass::epilogue::thread::ScaleType::Nothing
|
||||
: cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
using EpilogueOutputOp =
|
||||
|
||||
@ -38,9 +38,9 @@
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include <cmath>
|
||||
#include <cinttypes>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
@ -72,8 +72,6 @@
|
||||
#include "gemm_kernel_utils.h"
|
||||
#include "transform/tile_smem_loader.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
using namespace gemm_kernel_utils;
|
||||
|
||||
namespace {
|
||||
@ -1037,15 +1035,15 @@ struct AttentionKernel {
|
||||
using EpilogueOutputOp = typename cutlass::epilogue::
|
||||
thread::MemoryEfficientAttentionNormalize<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
output_t,
|
||||
output_accum_t>::type,
|
||||
output_accum_t,
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator,
|
||||
ElementCompute,
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
kIsFirst::value,
|
||||
kIsLast::value,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
EpiloguePipelined<
|
||||
@ -1053,7 +1051,7 @@ struct AttentionKernel {
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
typename MM1::OutputTileIterator,
|
||||
typename MM1::OutputTileIteratorAccum>::type,
|
||||
typename DefaultEpilogue::
|
||||
@ -1071,7 +1069,7 @@ struct AttentionKernel {
|
||||
int col = blockN * MM1::Mma::Shape::kN;
|
||||
auto source_iter = createOutputAccumIter(col);
|
||||
auto dest_iter = call_conditional<
|
||||
kIsLast,
|
||||
kIsLast::value,
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
|
||||
@ -452,7 +452,7 @@ public:
|
||||
// Determine SMEM requirements and waive if not satisfied
|
||||
//
|
||||
|
||||
int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage));
|
||||
size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage);
|
||||
|
||||
cudaDeviceProp properties;
|
||||
int device_idx;
|
||||
@ -509,7 +509,7 @@ public:
|
||||
);
|
||||
|
||||
// Initialize the GEMM object
|
||||
Gemm gemm;
|
||||
Gemm gemm{};
|
||||
|
||||
result.status = gemm.initialize(args);
|
||||
|
||||
|
||||
@ -39,11 +39,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@ -53,12 +49,9 @@
|
||||
#include "cutlass/tensor_coord.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ class gen_test:
|
||||
|
||||
def gen_cpp_sample(self):
|
||||
code = "/* Auto Generated code - Do not edit.*/\n"
|
||||
code += "#include <stdio.h> \n"
|
||||
code += "#include <cstdio> \n"
|
||||
|
||||
code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n"
|
||||
code += "#include \"cutlass/cutlass.h\" \n"
|
||||
|
||||
@ -380,7 +380,7 @@ class gen_one_API:
|
||||
def gen_CUTLASS_irrelevant_API(self):
|
||||
code = ""
|
||||
code += "#include <cuda_runtime.h>\n"
|
||||
code += "#include <assert.h>\n"
|
||||
code += "#include <cassert>\n"
|
||||
|
||||
param_name = "Fused" + str(self.b2b_num) + "xGemm_"
|
||||
for i in range(self.b2b_num):
|
||||
|
||||
@ -66,7 +66,7 @@ int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string &
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == arch_major && props.minor == arch_minor)) {
|
||||
if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) {
|
||||
supported = false;
|
||||
}
|
||||
|
||||
|
||||
@ -38,11 +38,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@ -35,19 +35,23 @@
|
||||
This example demonstrate a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0
|
||||
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and
|
||||
convert them implicitly to TF32.
|
||||
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
|
||||
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
|
||||
Examples:
|
||||
|
||||
$ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048
|
||||
$ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
@ -63,6 +67,7 @@
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
@ -105,7 +110,7 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
|
||||
using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
@ -175,6 +180,8 @@ cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
@ -183,12 +190,16 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(5120), n(4096), k(4096),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(1000)
|
||||
iterations(1000),
|
||||
raster(RasterOrderOptions::Heuristic),
|
||||
swizzle(1)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -206,6 +217,21 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -220,6 +246,8 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
@ -275,14 +303,14 @@ bool initialize_block(
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
@ -294,10 +322,10 @@ bool initialize_block(
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
@ -320,6 +348,10 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
@ -408,7 +440,17 @@ int run(Options &options)
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
@ -441,7 +483,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -538,9 +538,8 @@ int main(int argc, char const **args) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -354,9 +354,8 @@ int main(int argc, char const **args) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -102,7 +102,8 @@ gett_kernel(
|
||||
ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
ElementAccumulator,
|
||||
TileShape, Shape<_1,_2,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
|
||||
@ -45,18 +45,18 @@
|
||||
and BEFORE scatter operations are applied.
|
||||
*/
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <numeric>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
@ -64,7 +64,6 @@
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
@ -289,7 +288,8 @@ struct ExampleRunner
|
||||
ElementAccumulator,
|
||||
Shape<_128,_128,_64>,
|
||||
Shape<_2,_2,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOpt::SharedStorage)>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename EpilogueOpt::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
@ -626,7 +626,6 @@ int main(int argc, const char ** argv) {
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -39,6 +39,11 @@
|
||||
|
||||
#include "gather_tensor.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
///Forward declaration
|
||||
struct CudaHostAdapter;
|
||||
}
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -106,7 +111,7 @@ public:
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16> {
|
||||
struct PipelineStorage : cute::aligned_struct<16, _2> {
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
|
||||
@ -143,10 +148,10 @@ public:
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode;
|
||||
ProblemShape problem_shape;
|
||||
MainloopParams mainloop;
|
||||
EpilogueParams epilogue;
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
GatherA gather_A{};
|
||||
GatherB gather_B{};
|
||||
};
|
||||
@ -161,7 +166,7 @@ public:
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
|
||||
// swap M/N
|
||||
get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
@ -176,8 +181,7 @@ public:
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE static
|
||||
bool
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
||||
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
@ -191,14 +195,15 @@ public:
|
||||
}
|
||||
|
||||
static
|
||||
int
|
||||
size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static
|
||||
cutlass::Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/numeric/int.hpp"
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
|
||||
#include "gather_tensor.hpp"
|
||||
|
||||
@ -119,7 +119,7 @@ public:
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
CUTLASS_HOST_DEVICE static bool
|
||||
static bool
|
||||
can_implement(
|
||||
[[maybe_unused]] ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
|
||||
@ -393,7 +393,8 @@ private:
|
||||
ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
@ -403,7 +404,8 @@ private:
|
||||
ElementB, StrideBPermute, 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
ElementAccumulator,
|
||||
TileShapePermute, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpiloguePermute::SharedStorage)>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpiloguePermute::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
@ -748,7 +750,6 @@ int main(int argc, char const **argv)
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -37,7 +37,7 @@
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cute/numeric/uint128.hpp"
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
|
||||
namespace example
|
||||
{
|
||||
|
||||
@ -50,7 +50,7 @@ struct PermuteTraits {};
|
||||
using X = Underscore;
|
||||
|
||||
// Reshape a rank-2 shape into a multidimensional shape.
|
||||
// Input:
|
||||
// Input:
|
||||
// shape = (A, B, ...)
|
||||
// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...)
|
||||
// Output:
|
||||
@ -76,12 +76,12 @@ reshape(Shape const& shape, TargetShape const& target_shape)
|
||||
// - sub-modes corresponding to the implied multidimensional shape of the source tensor
|
||||
// - strides accounting for the permutation operation being performed
|
||||
template<class Permute, bool Transpose, class Shape, class Stride>
|
||||
constexpr auto
|
||||
constexpr auto
|
||||
make_permute_layout(Layout<Shape,Stride> const& layout) {
|
||||
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
|
||||
if constexpr (Transpose) {
|
||||
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
|
||||
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
|
||||
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
|
||||
return select<1,0,2>(make_permute_layout<Permute, false>(select<1,0,2>(layout)));
|
||||
}
|
||||
else {
|
||||
@ -129,23 +129,24 @@ inverse(Permutation const & perm) {
|
||||
template<class T>
|
||||
using inverse_t = decltype(inverse(T{}));
|
||||
|
||||
// Given a rank-2 layout of tensor that is assumed to have been permuted,
|
||||
// Given a rank-2 layout of tensor that is assumed to have been permuted,
|
||||
// compute the original rank-2 layout of the tensor prior to the permutation.
|
||||
// This is needed to form the correct input to the standalone permutation kernel.
|
||||
// This is needed to form the correct input to the standalone permutation kernel.
|
||||
template<class Permute, bool Transpose, class Shape, class Stride>
|
||||
constexpr auto
|
||||
constexpr auto
|
||||
make_original_layout(Layout<Shape,Stride> const& layout) {
|
||||
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
|
||||
if constexpr (Transpose) {
|
||||
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
|
||||
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
|
||||
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
|
||||
return select<1,0,2>(make_original_layout<Permute, false>(select<1,0,2>(layout)));
|
||||
}
|
||||
else {
|
||||
using ShapeProfile = typename PermuteTraits<Permute>::ShapeProfile;
|
||||
auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{}));
|
||||
using IndexOrder = typename PermuteTraits<Permute>::IndexOrder;
|
||||
auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get<i>(re_shape); });
|
||||
using OrigOrder = conditional_t<cutlass::gemm::detail::is_major<0,Stride>(), seq<0,1,2>, seq<1,0,2>>;
|
||||
auto orig_shape = select(flatten(reshape(layout.shape(), ShapeProfile{})), IndexOrder{});
|
||||
// print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n");
|
||||
// print("Original shape: "); print(orig_shape); print("\n");
|
||||
return make_ordered_layout(product_each(orig_shape), OrigOrder{});
|
||||
@ -202,7 +203,7 @@ struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D>>
|
||||
};
|
||||
|
||||
template<int D>
|
||||
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D>>
|
||||
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D>>
|
||||
{
|
||||
static constexpr bool kBatched = true;
|
||||
using ShapeProfile = Shape<Shape<X,Int<D>>, Shape<X>, Shape<X>>;
|
||||
@ -222,7 +223,7 @@ struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D>>
|
||||
};
|
||||
|
||||
template<int D>
|
||||
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D>>
|
||||
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D>>
|
||||
{
|
||||
static constexpr bool kBatched = true;
|
||||
using ShapeProfile = Shape<Shape<X>, Shape<X,Int<D>>, Shape<X>>;
|
||||
|
||||
@ -47,9 +47,13 @@
|
||||
4. This example shows all important fusions used by FP8 gemm kernels,
|
||||
i.e., scale factor for A, B, C, D tensor, the abs_max value of D tensor.
|
||||
|
||||
5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
|
||||
Examples:
|
||||
|
||||
$ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048
|
||||
$ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
@ -63,6 +67,7 @@
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
@ -214,6 +219,8 @@ cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
@ -273,7 +280,7 @@ bool initialize_tensor(
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
@ -346,7 +353,7 @@ void initialize(const Options &options) {
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
@ -392,10 +399,14 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
bool verify(const Options<RasterOrderOptions> &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
@ -468,7 +479,7 @@ bool verify(const Options &options) {
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
int run(Options<RasterOrderOptions> &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
@ -518,7 +529,17 @@ int run(Options &options)
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
@ -551,12 +572,11 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
Options<RasterOrderOptions> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
template<typename RasterOrderOptions>
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
@ -41,6 +42,8 @@ struct Options {
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@ -66,6 +69,21 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -89,6 +107,8 @@ struct Options {
|
||||
<< " --device_scale=<bool> Copy scalars to device memory before kernel launch (default: false)\n"
|
||||
<< " --save_aux=<bool> Save the pre-activation as an auxiliary tensor (default: true)\n"
|
||||
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
|
||||
683
examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu
Normal file
683
examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu
Normal file
@ -0,0 +1,683 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
|
||||
This example shows how to perform INT4 x BF16 GEMM and scale up the INT4 weight during dequantization.
|
||||
|
||||
The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
|
||||
A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue,
|
||||
as illustrated in this example.
|
||||
|
||||
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
|
||||
|
||||
As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory.
|
||||
This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in 16-bit data type, each thread reads
|
||||
4 groups of 2 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-a for thread-value layout).
|
||||
If the narrow type is INT4 and tensor is major in K dim, only 8 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
|
||||
If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized
|
||||
tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the options.shuffle is set to true.
|
||||
|
||||
Furthermore, the conversion from {INT4, UINT4} to {FP16, BF16} can benefit from pre-shuffling the weights in the order [0,2,4,6,1,3,5,7]. This allows multiple nibbles to be efficiently extracted and up-converted
|
||||
in parallel. The reordering is enabled by defining the layout type `ValueShuffle`. Refer to the partial specializations of `NumericArrayShuffleConverter` in "include/cutlass/detail/collective/mixed_input_utils.hpp"
|
||||
for more details.
|
||||
|
||||
It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).
|
||||
|
||||
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
|
||||
|
||||
If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k].
|
||||
|
||||
The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size
|
||||
equal to the gemm problem K.
|
||||
|
||||
Limitations:
|
||||
1) The INT4 weights have additional encoding requirements.
|
||||
2) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
|
||||
3) The scales must have the same layout and groupsize.
|
||||
4) The groupsize must be greater or equal to the tile shape k.
|
||||
5) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
|
||||
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
|
||||
We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands.
|
||||
|
||||
Optimizing suggestions:
|
||||
1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space).
|
||||
|
||||
Examples:
|
||||
|
||||
Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0)
|
||||
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0
|
||||
|
||||
Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire
|
||||
matrix (group size is the same as the gemm k dimension).
|
||||
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "helper.h"
|
||||
#include "mixed_dtype_utils.hpp"
|
||||
#include "packed_scale.hpp"
|
||||
#include "reorder_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
using MmaType = cutlass::bfloat16_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
|
||||
// It specifies the reordering within a single warp's fragment
|
||||
//using ValueShuffle = Layout<_1>; // no value reordering
|
||||
using ValueShuffle = Layout<Shape<_2,_4>, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7]
|
||||
int constexpr NumShuffleAtoms = 1;
|
||||
using MmaAtomShape = Layout<Shape<_1,Int<NumShuffleAtoms>>>;
|
||||
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
|
||||
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
|
||||
using ElementScale = MmaType;
|
||||
using ElementZero = ElementScale;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
// Transpose layout of D here since we use explicit swap + transpose
|
||||
// the void type for C tells the builder to allocate 0 smem for the C matrix.
|
||||
// We can enable this if beta == 0 by changing ElementC to void below.
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
|
||||
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
|
||||
>::CollectiveOp;
|
||||
|
||||
// ============================================================ MIXED INPUT NO SCALES ============================================================================
|
||||
// The collective will infer that the narrow type should be upcasted to the wide type.
|
||||
// We swap A and B operands to the builder here
|
||||
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementB, LayoutB_Transpose, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopConvertOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
|
||||
|
||||
using CollectiveMainloopConvertOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementB, LayoutB_Reordered, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelConvertOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopConvertOnlyShuffled,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmConvertOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnlyShuffled>;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
|
||||
using CollectiveMainloopScaleOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Reordered, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleOnlyShuffled,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnlyShuffled>;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ==================================================================
|
||||
// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format.
|
||||
using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale, ElementZero>, LayoutB_Transpose, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleWithZeroPoint,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleWithZeroPoint>;
|
||||
|
||||
using CollectiveMainloopScaleWithZeroPointShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale, ElementZero>, LayoutB_Reordered, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleWithZeroPointShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleWithZeroPointShuffled,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleWithZeroPointShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleWithZeroPointShuffled>;
|
||||
// =================================================================================================================================================================
|
||||
|
||||
using StrideC = typename GemmKernelScaleOnly::StrideC;
|
||||
using StrideD = typename GemmKernelScaleOnly::StrideD;
|
||||
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideC_ref stride_C_ref;
|
||||
StrideD stride_D;
|
||||
StrideD_ref stride_D_ref;
|
||||
uint64_t seed;
|
||||
|
||||
LayoutB_Reordered layout_B_reordered;
|
||||
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
StrideS stride_S;
|
||||
StrideS_ref stride_S_ref;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementA> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options : MixedDtypeOptions{
|
||||
bool shuffle = true;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("shuffle", shuffle);
|
||||
|
||||
this->MixedDtypeOptions::parse(argc, args);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "55_hopper_int4_bf16_gemm\n\n"
|
||||
<< " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
||||
<< " --g=<int> The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
||||
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --warmup=<int> Number of warmup iterations to perform.\n\n"
|
||||
<< " --shuffle=<boolean> Enable the offline layout swizzling.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "55_hopper_int4_bf16_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto layout_B = make_layout(shape_B, stride_B);
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
|
||||
block_A.reset(a_coord.product());
|
||||
block_B.reset(b_coord.product());
|
||||
block_B_dq.reset(b_coord.product());
|
||||
block_C.reset(c_coord.product());
|
||||
block_D.reset(c_coord.product());
|
||||
block_ref_D.reset(c_coord.product());
|
||||
|
||||
block_scale.reset(scale_k * options.l * options.n);
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
|
||||
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
|
||||
if (options.shuffle) {
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
|
||||
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered);
|
||||
print("\n");
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
/// Swap the A and B tensors, as well as problem shapes here.
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options const& options)
|
||||
{
|
||||
using Args = typename Gemm::Arguments;
|
||||
auto&& dB = [&]() {
|
||||
if constexpr (cute::is_same_v<Gemm, GemmConvertOnlyShuffled> ||
|
||||
cute::is_same_v<Gemm, GemmScaleOnlyShuffled> ||
|
||||
cute::is_same_v<Gemm, GemmScaleWithZeroPointShuffled>) {
|
||||
// offline swizzling is enabled.
|
||||
return layout_B_reordered;
|
||||
}
|
||||
else {
|
||||
return stride_B;
|
||||
}
|
||||
}();
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B.get(), dB, block_A.get(), stride_A},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B.get(), dB, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B.get(), dB, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
} else {
|
||||
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// compare_reference
|
||||
ElementD const epsilon(1e-2f);
|
||||
ElementD const non_zero_floor(1e-4f);
|
||||
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
mixed_dtype_profiling(gemm, options, result);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
std::cout << "Running in no scale mode." << std::endl;
|
||||
if (options.shuffle) {
|
||||
std::cout << "Offline shuffle enabled." << std::endl;
|
||||
run<GemmConvertOnlyShuffled>(options);
|
||||
} else {
|
||||
std::cout << "Offline shuffle disabled." << std::endl;
|
||||
run<GemmConvertOnly>(options);
|
||||
}
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
|
||||
if (options.g == options.k) {
|
||||
std::cout << "Running in per-column scale mode." << std::endl;
|
||||
} else {
|
||||
std::cout << "Running in group scale mode." << std::endl;
|
||||
}
|
||||
if (options.shuffle) {
|
||||
std::cout << "Offline shuffle enabled." << std::endl;
|
||||
run<GemmScaleOnlyShuffled>(options);
|
||||
} else {
|
||||
std::cout << "Offline shuffle disabled." << std::endl;
|
||||
run<GemmScaleOnly>(options);
|
||||
}
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
if (options.g == options.k) {
|
||||
std::cout << "Running in per-column scale and zero mode." << std::endl;
|
||||
} else {
|
||||
std::cout << "Running in group scale and zero mode." << std::endl;
|
||||
}
|
||||
if (options.shuffle) {
|
||||
std::cout << "Offline shuffle enabled." << std::endl;
|
||||
run<GemmScaleWithZeroPointShuffled>(options);
|
||||
} else {
|
||||
std::cout << "Offline shuffle disabled." << std::endl;
|
||||
run<GemmScaleWithZeroPoint>(options);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
562
examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
Normal file
562
examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
Normal file
@ -0,0 +1,562 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
|
||||
This example shows how to perform INT4 x FP8 GEMM and scale up the INT4 weight during dequantization. It uses a look-up table to avoid the multiplications
|
||||
between INT4 and FP8. To trigger this method, use cutlass::Array<ElementScale, 8> as the scale type in the collective's arguments.
|
||||
|
||||
However, this algorithm requires changes to the encoding of INT4 weights and scale factors. These changes must happen before launching the GEMM. See the helper functions
|
||||
`unify_quant_encoding`, `initialize_packed_scale` in the header `fp8_packed_scale.hpp` for details.
|
||||
|
||||
In a nutshell, the positive values of INT4 weights need to be encoded in the same way as negative values except for the sign bit. For each scale factor,
|
||||
8 negative results (-8 x scale, -7 x scale, ... -1 x scale) are packed together, forming a cutlass::Array<ElementScale, 8> value.
|
||||
|
||||
The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
|
||||
A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue,
|
||||
as illustrated in this example.
|
||||
|
||||
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
|
||||
|
||||
As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory.
|
||||
This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in FP8 data type, each thread reads
|
||||
4 groups of 4 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n32-a for thread-value layout).
|
||||
If the narrow type is INT4 and tensor is major in K dim, only 16 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
|
||||
If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized
|
||||
tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the options.shuffle is set to true.
|
||||
|
||||
It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).
|
||||
|
||||
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
|
||||
|
||||
If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k].
|
||||
|
||||
The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size
|
||||
equal to the gemm problem K.
|
||||
|
||||
Limitations:
|
||||
1) Only supports INT4 x { FP8, INT8, UINT8 }. The scales must be the same as mma Type. Scale with zero-point mode is not supported.
|
||||
2) The INT4 weights and scale factors have additional encoding requirements.
|
||||
3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
|
||||
4) The scales must have the same layout and groupsize.
|
||||
5) The groupsize must be greater or equal to the tile shape k.
|
||||
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
|
||||
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
|
||||
We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands.
|
||||
|
||||
Optimizing suggestions:
|
||||
1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space).
|
||||
|
||||
Examples:
|
||||
|
||||
Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0)
|
||||
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0
|
||||
|
||||
Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire
|
||||
matrix (group size is the same as the gemm k dimension).
|
||||
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "helper.h"
|
||||
#include "mixed_dtype_utils.hpp"
|
||||
#include "packed_scale.hpp"
|
||||
#include "reorder_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
using MmaType = cutlass::float_e4m3_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
|
||||
// It specifies the reordering within a single warp's fragment
|
||||
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
|
||||
using ElementScale = MmaType;
|
||||
using ElementZero = ElementScale; // only for verify
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
// Transpose layout of D here since we use explicit swap + transpose
|
||||
// the void type for C tells the builder to allocate 0 smem for the C matrix.
|
||||
// We can enable this if beta == 0 by changing ElementC to void below.
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
|
||||
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
|
||||
>::CollectiveOp;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Transpose, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using CollectiveMainloopShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Reordered, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopShuffled,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
using GemmShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;
|
||||
|
||||
using StrideC = typename GemmKernelScaleOnly::StrideC;
|
||||
using StrideD = typename GemmKernelScaleOnly::StrideD;
|
||||
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideC_ref stride_C_ref;
|
||||
StrideD stride_D;
|
||||
StrideD_ref stride_D_ref;
|
||||
uint64_t seed;
|
||||
|
||||
LayoutB_Reordered layout_B_reordered;
|
||||
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
StrideS stride_S;
|
||||
StrideS_ref stride_S_ref;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementB> block_B_modified;
|
||||
cutlass::DeviceAllocation<ElementA> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8>> block_scale_packed;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options : MixedDtypeOptions {
|
||||
bool shuffle = true;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("shuffle", shuffle);
|
||||
|
||||
this->MixedDtypeOptions::parse(argc, args);
|
||||
|
||||
mode = 1; // override the mode value to always be scale only mode
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "55_hopper_int4_fp8_gemm\n\n"
|
||||
<< " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
||||
<< " --g=<int> The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --warmup=<int> Number of warmup iterations to perform.\n\n"
|
||||
<< " --shuffle=<boolean> Enable the offline layout swizzling.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "55_hopper_int4_fp8_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto layout_B = make_layout(shape_B, stride_B);
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
|
||||
block_A.reset(a_coord.product());
|
||||
block_B.reset(b_coord.product());
|
||||
block_B_modified.reset(b_coord.product());
|
||||
block_B_dq.reset(b_coord.product());
|
||||
block_C.reset(c_coord.product());
|
||||
block_D.reset(c_coord.product());
|
||||
block_ref_D.reset(c_coord.product());
|
||||
|
||||
block_scale.reset(scale_k * options.l * options.n);
|
||||
block_scale_packed.reset(scale_k * options.l * options.n);
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
unify_quant_encoding(block_B, block_B_modified);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_packed_scale(block_scale, block_scale_packed);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
|
||||
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
|
||||
if (options.shuffle) {
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
|
||||
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered);
|
||||
print("\n");
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
/// Swap the A and B tensors, as well as problem shapes here.
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options const& options)
|
||||
{
|
||||
using Args = typename Gemm::Arguments;
|
||||
auto&& dB = [&]() {
|
||||
if constexpr (cute::is_same_v<Gemm, GemmShuffled>) { // offline swizzling is enabled.
|
||||
return layout_B_reordered;
|
||||
}
|
||||
else {
|
||||
return stride_B;
|
||||
}
|
||||
}();
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B_modified.get(), dB, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// In this example, we use the GPU default kernels as a reference (unfused scale).
|
||||
// This avoids numerical differences due to different accumulation order.
|
||||
|
||||
// Again, due to numerical differences, we must use fast acc here when the mma type is
|
||||
// FP8 as the fused implementation only supports fast acc at the moment.
|
||||
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
|
||||
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
|
||||
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
ScheduleRef
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// compare_reference
|
||||
ElementD const epsilon(1e-2f);
|
||||
ElementD const non_zero_floor(1e-4f);
|
||||
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
mixed_dtype_profiling(gemm, options, result);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
if (options.g == options.k) {
|
||||
std::cout << "Running in per-column scale mode." << std::endl;
|
||||
} else {
|
||||
std::cout << "Running in group scale mode." << std::endl;
|
||||
}
|
||||
if (options.shuffle) {
|
||||
std::cout << "Offline shuffle enabled." << std::endl;
|
||||
run<GemmShuffled>(options);
|
||||
} else {
|
||||
std::cout << "Offline shuffle disabled." << std::endl;
|
||||
run<GemmScaleOnly>(options);
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -53,14 +53,18 @@
|
||||
equal to the gemm problem K.
|
||||
|
||||
Limitations:
|
||||
1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}.
|
||||
2) The narrow type must always be in K-major format.
|
||||
3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
|
||||
4) The scales and the zeros must have the same layout and groupsize.
|
||||
1) The narrow type must always be in K-major format.
|
||||
2) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
|
||||
3) The scales and the zeros must have the same layout and groupsize.
|
||||
4) The groupsize must be greater or equal to tile shape k.
|
||||
5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format.
|
||||
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
|
||||
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
|
||||
We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands.
|
||||
|
||||
Optimizing suggestions:
|
||||
1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space).
|
||||
2) Try avoid using scale or zero mode cause the computations will be the bottleneck.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -94,31 +98,21 @@
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "helper.h"
|
||||
#include "unfused_weight_dequantize.hpp"
|
||||
#include "mixed_dtype_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
// This is just an example, so we use a regular enum so we can compare directly to the command-line int.
|
||||
enum GemmMode {
|
||||
ConvertOnly,
|
||||
ScaleOnly,
|
||||
ScaleWithZeroPoint
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
using MmaType = cutlass::float_e4m3_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
using MmaType = cutlass::half_t;
|
||||
using QuantType = cutlass::float_e4m3_t;
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
// A matrix configuration
|
||||
@ -154,9 +148,9 @@ using ElementAccumulator = float; // E
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_256,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
@ -268,14 +262,14 @@ using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
StrideS stride_S;
|
||||
StrideS_ref stride_S_ref;
|
||||
|
||||
cutlass::HostTensor<MmaType, LayoutA> tensor_A;
|
||||
cutlass::HostTensor<QuantType, LayoutB> tensor_B;
|
||||
cutlass::HostTensor<MmaType, LayoutB> tensor_B_dq;
|
||||
cutlass::HostTensor<ElementScale, LayoutScale> tensor_scale;
|
||||
cutlass::HostTensor<ElementZero, LayoutScale> tensor_zero;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
|
||||
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_D;
|
||||
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_ref_D;
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementA> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
@ -283,179 +277,17 @@ cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOu
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int iterations = 1000;
|
||||
int mode = 2;
|
||||
int m = 5120, n = 4096, k = 4096;
|
||||
int g = 128;
|
||||
int l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("g", g);
|
||||
cmd.get_cmd_line_argument("mode", mode);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "55_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
||||
<< " --g=<int> The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
||||
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element, class Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_quant_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
|
||||
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element, class Layout>
|
||||
bool initialize_scale(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
const Options &options) {
|
||||
|
||||
if (options.mode == GemmMode::ConvertOnly) {
|
||||
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
|
||||
cutlass::reference::host::TensorFill(view, Element(1.0f));
|
||||
}
|
||||
else {
|
||||
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
|
||||
const float max_dequant_val = 4.f;
|
||||
const float min_dequant_val = 0.5f;
|
||||
|
||||
float scope_max(max_dequant_val / elt_max_f);
|
||||
float scope_min(min_dequant_val / elt_max_f);
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element, class Layout>
|
||||
bool initialize_zero(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
const Options &options) {
|
||||
|
||||
if (options.mode == GemmMode::ScaleWithZeroPoint) {
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2.0f, -2.0f);
|
||||
} else {
|
||||
// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
|
||||
cutlass::reference::host::TensorFill(view, Element(0.0f));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
void initialize(MixedDtypeOptions const& options) {
|
||||
|
||||
auto shape_b = cute::make_shape(options.n, options.k, options.l);
|
||||
const int scale_k = (options.k + options.g - 1) / options.g;
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@ -469,27 +301,21 @@ void initialize(const Options &options) {
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_B_dq.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
block_A.reset(a_coord.product());
|
||||
block_B.reset(b_coord.product());
|
||||
block_B_dq.reset(b_coord.product());
|
||||
block_C.reset(c_coord.product());
|
||||
block_D.reset(c_coord.product());
|
||||
block_ref_D.reset(c_coord.product());
|
||||
|
||||
tensor_scale.resize({scale_k * options.l, options.n});
|
||||
tensor_zero.resize({scale_k * options.l, options.n});
|
||||
block_scale.reset(scale_k * options.l * options.n);
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2022);
|
||||
initialize_quant_tensor(tensor_B.host_view(), seed + 2021);
|
||||
initialize_tensor(tensor_C.host_view(), seed + 2020);
|
||||
initialize_scale(tensor_scale.host_view(), options);
|
||||
initialize_zero(tensor_zero.host_view(), options);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_scale.sync_device();
|
||||
tensor_zero.sync_device();
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
auto layout_B = make_layout(shape_b, stride_B);
|
||||
|
||||
@ -498,37 +324,36 @@ void initialize(const Options &options) {
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g);
|
||||
tensor_B_dq.sync_host();
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Args>
|
||||
Args args_from_options(const Options &options)
|
||||
Args args_from_options(MixedDtypeOptions const& options)
|
||||
{
|
||||
// Swap the A and B tensors, as well as problem shapes here.
|
||||
if (options.mode == GemmMode::ConvertOnly) {
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A},
|
||||
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
|
||||
{block_B.get(), stride_B, block_A.get(), stride_A},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
else if (options.mode == GemmMode::ScaleOnly) {
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g},
|
||||
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
|
||||
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()},
|
||||
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
|
||||
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
} else {
|
||||
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
|
||||
@ -536,13 +361,13 @@ Args args_from_options(const Options &options)
|
||||
}
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
bool verify(MixedDtypeOptions const& options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// In this example, we use the GPU default kernels as a reference (unfused scale)
|
||||
// This is to avoid numerical differences from different accumulation order.
|
||||
// This avoids numerical differences due to different accumulation order.
|
||||
|
||||
// Again, due to numerical differences, we must use fast acc here when the mma type is
|
||||
// FP8 as the fused implementation only supports fast acc at the moment.
|
||||
@ -581,8 +406,8 @@ bool verify(const Options &options) {
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B},
|
||||
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref}
|
||||
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
@ -594,17 +419,15 @@ bool verify(const Options &options) {
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
tensor_ref_D.sync_host();
|
||||
const ElementD epsilon(1e-2f);
|
||||
const ElementD non_zero_floor(1e-4f);
|
||||
bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor);
|
||||
ElementD const epsilon(1e-2f);
|
||||
ElementD const non_zero_floor(1e-4f);
|
||||
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
int run(MixedDtypeOptions &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
@ -630,35 +453,14 @@ int run(Options &options)
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
|
||||
mixed_dtype_profiling(gemm, options, result);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -687,12 +489,11 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
MixedDtypeOptions options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
@ -706,11 +507,11 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
if (options.mode == GemmMode::ConvertOnly) {
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
std::cout << "Running in no scale mode." << std::endl;
|
||||
run<GemmConvertOnly>(options);
|
||||
}
|
||||
else if (options.mode == GemmMode::ScaleOnly) {
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
|
||||
if (options.g == options.k) {
|
||||
std::cout << "Running in per-column scale mode." << std::endl;
|
||||
} else {
|
||||
@ -718,7 +519,7 @@ int main(int argc, char const **args) {
|
||||
}
|
||||
run<GemmScaleOnly>(options);
|
||||
}
|
||||
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
if (options.g == options.k) {
|
||||
std::cout << "Running in per-column scale and zero mode." << std::endl;
|
||||
} else {
|
||||
|
||||
@ -55,5 +55,27 @@ cutlass_example_add_executable(
|
||||
TEST_SCALE_ZERO_GROUPED
|
||||
TEST_SCALE_RESIDUE
|
||||
TEST_SCALE_ZERO_RESIDUE
|
||||
TEST_ALPHA_BETA
|
||||
# TEST_ALPHA_BETA
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
55_hopper_int4_fp8_gemm
|
||||
55_hopper_int4_fp8_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
TEST_SCALE_RESIDUE
|
||||
# TEST_ALPHA_BETA
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
55_hopper_int4_bf16_gemm
|
||||
55_hopper_int4_bf16_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
TEST_SCALE_RESIDUE
|
||||
# TEST_ALPHA_BETA
|
||||
)
|
||||
|
||||
@ -3,16 +3,23 @@ This example shows how to do mixed types GEMMs in CUTLASS.
|
||||
## High level overview
|
||||
This example shows how to perform GEMMs on Hopper when A and B have different types. This implementation always passes the type with fewer bits through the register file and upcasts to the type with the higher bit count.
|
||||
|
||||
When relying on `KernelScheduleAuto`, the main loop supporting different A and B types will be selected whenever the bit count of A is not equal to the bit count of B. Users can manually select the mixed type main loop and explicitly choose the scheduling policy by specifying one of the following schedules to the `CollectiveBuilder`: `KernelTmaWarpSpecializedMixedInput`, `KernelTmaWarpSpecializedPingpongMixedInput` or `KernelTmaWarpSpecializedCooperativeMixedInput`.
|
||||
When relying on `KernelScheduleAuto`, the main loop supporting different A and B types will be selected whenever the bit count of A is not equal to the bit count of B. Users can manually select the mixed type main loop and explicitly choose the scheduling policy by specifying one of the following schedules to the `CollectiveBuilder`: `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative`.
|
||||
|
||||
This first version only supports mixed type GEMMs using TMA.
|
||||
|
||||
## Performance
|
||||
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type.
|
||||
|
||||
The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array<ElementScale, 8>` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now.
|
||||
|
||||
|
||||
Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details.
|
||||
|
||||
|
||||
We are currently optimizing the following cases:
|
||||
1. Memory bound cases for all types
|
||||
2. `fp8 x {int2, uint2}` case
|
||||
|
||||
## Limitations
|
||||
|
||||
|
||||
391
examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp
Normal file
391
examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp
Normal file
@ -0,0 +1,391 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <numeric>
|
||||
#include "helper.h"
|
||||
|
||||
enum MixedDtypeGemmMode {
|
||||
ConvertOnly,
|
||||
ScaleOnly,
|
||||
ScaleWithZeroPoint
|
||||
};
|
||||
|
||||
/// Command line options parsing
|
||||
struct MixedDtypeOptions {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int iterations = 1000;
|
||||
int warmup = 1000;
|
||||
int mode = 1;
|
||||
int m = 5120, n = 4096, k = 4096;
|
||||
int g = 128;
|
||||
int l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("g", g);
|
||||
cmd.get_cmd_line_argument("mode", mode);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("warmup", warmup);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "55_hopper_mixed_dtype_gemm\n\n"
|
||||
<< " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
||||
<< " --g=<int> The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
||||
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --warmup=<int> Number of warmup iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "55_hopper_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct MixedDtypeResult
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
|
||||
};
|
||||
|
||||
/// Profiling Loop
|
||||
template <class Gemm>
|
||||
void mixed_dtype_profiling(
|
||||
Gemm& gemm,
|
||||
MixedDtypeOptions const& options,
|
||||
MixedDtypeResult& result) {
|
||||
|
||||
if (options.iterations <= 0) return;
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
std::vector<float> runtimes;
|
||||
runtimes.reserve(options.iterations);
|
||||
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
}
|
||||
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
|
||||
// Compute average setup and runtime and GFLOPs.
|
||||
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
|
||||
}
|
||||
|
||||
/// Helpers to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
bool initialize_quant_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
|
||||
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_scale(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
MixedDtypeOptions const& options,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
|
||||
std::vector<Element> stage(block.size(), Element(1.0f));
|
||||
block.copy_from_host(stage.data());
|
||||
}
|
||||
else {
|
||||
float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
const float max_dequant_val = 4.f;
|
||||
const float min_dequant_val = 0.5f;
|
||||
|
||||
float scope_max(max_dequant_val / elt_max_f);
|
||||
float scope_min(min_dequant_val / elt_max_f);
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_zero(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
MixedDtypeOptions const& options,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
|
||||
} else {
|
||||
// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
|
||||
std::vector<Element> stage(block.size(), Element(0.0f));
|
||||
block.copy_from_host(stage.data());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Dequantize the weights for verification
|
||||
|
||||
template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleBroadCastLayout,
|
||||
class ThrLayout>
|
||||
__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleBroadCastLayout const broadcasted_scale_layout,
|
||||
ThrLayout thr_layout) {
|
||||
using namespace cute;
|
||||
|
||||
// Represent the full tensors to gmem elements.
|
||||
// These are expected to have shape [MN, K, L]
|
||||
cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
|
||||
auto init_quantized_iterator = [&]() {
|
||||
if constexpr (cute::sizeof_bits_v<QuantizedElement> >= 8) {
|
||||
return cute::make_gmem_ptr(q_buffer);
|
||||
} else {
|
||||
return cute::subbyte_iterator<const QuantizedElement>(q_buffer);
|
||||
}
|
||||
};
|
||||
cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout);
|
||||
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
|
||||
// It is expected that K % G == 0
|
||||
cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
|
||||
cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
|
||||
|
||||
// Assign 1 thread per element in the thread block
|
||||
auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); //
|
||||
auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
|
||||
|
||||
// Tile across the block
|
||||
auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
|
||||
auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
|
||||
auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
|
||||
auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
|
||||
|
||||
auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
|
||||
auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
|
||||
auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
|
||||
auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
|
||||
|
||||
// Make a fragment of registers to hold gmem loads
|
||||
cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
|
||||
cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
|
||||
cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
|
||||
cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
|
||||
cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
|
||||
cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
|
||||
|
||||
cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
|
||||
auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
|
||||
auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
|
||||
|
||||
const auto num_iters = cute::size<3>(tOpDq_gOpDq);
|
||||
|
||||
for (int ii = 0; ii < num_iters; ++ii) {
|
||||
const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
|
||||
if (thread_offset < cute::size<0>(operand_layout)) {
|
||||
cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
|
||||
cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
|
||||
cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
|
||||
cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
|
||||
cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
|
||||
cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{});
|
||||
cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{});
|
||||
cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
|
||||
cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleLayout>
|
||||
void dequantize_weight(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleLayout const scale_layout,
|
||||
int const group_size) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
constexpr int tpb = 128;
|
||||
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
|
||||
|
||||
const auto num_rows = get<0>(shape(operand_layout));
|
||||
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
|
||||
|
||||
if (num_rows != size<0>(scale_layout)) {
|
||||
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
|
||||
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const auto scale_stride0 = get<0>(stride(scale_layout));
|
||||
const auto scale_stride1 = get<1>(stride(scale_layout));
|
||||
const auto scale_stride2 = get<2>(stride(scale_layout));
|
||||
|
||||
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
|
||||
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
|
||||
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
|
||||
|
||||
const auto blocks_x = gemm_k;
|
||||
const auto blocks_y = batches;
|
||||
|
||||
dim3 blocks(blocks_x, blocks_y, 1);
|
||||
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
212
examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp
Normal file
212
examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp
Normal file
@ -0,0 +1,212 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/util/type_traits.hpp"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
template<typename T>
|
||||
class packed_scale_t {
|
||||
public:
|
||||
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
|
||||
cute::is_same_v<T, cutlass::uint8_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e5m2_t>,
|
||||
"only 8 bit arithmetic types are supported.");
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit packed_scale_t(T val) {
|
||||
if constexpr (!cute::is_unsigned_v<T>) {
|
||||
// Only pack negative values. The positive values are generated in flight in the mainloop.
|
||||
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
|
||||
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
|
||||
}
|
||||
else {
|
||||
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
|
||||
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
|
||||
}
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
packed_scale_t() = default;
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit operator float() const {
|
||||
return float(get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(packed_scale_t const& rhs) const {
|
||||
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(packed_scale_t const& rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() + rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() - rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() * rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() / rhs.get());
|
||||
}
|
||||
|
||||
private:
|
||||
using Storage = uint32_t;
|
||||
using Stage = uint8_t;
|
||||
|
||||
Storage storage[2] {};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Storage pack4(T c1, T c2, T c3, T c4) {
|
||||
Storage result = 0;
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
|
||||
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
|
||||
return result;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get() const {
|
||||
auto stage = static_cast<Stage>(storage[0] >> 8);
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get(int idx) const {
|
||||
Stage stage;
|
||||
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
|
||||
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Helpers to initialize scale lookup table
|
||||
|
||||
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
|
||||
// Here the encodings of positive values and negative values are unified (except for the sign bit).
|
||||
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
|
||||
bool unify_quant_encoding(
|
||||
cutlass::DeviceAllocation<cutlass::int4b_t> const& block_in,
|
||||
cutlass::DeviceAllocation<cutlass::int4b_t>& block_out) {
|
||||
|
||||
using StorageType = cutlass::int4b_t::Storage;
|
||||
|
||||
if (block_in.size() != block_out.size()) {
|
||||
std::cerr << "block_in and block_out must have same size.\n";
|
||||
return false;
|
||||
}
|
||||
constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
|
||||
std::vector<StorageType> data(block_in.size() / pack);
|
||||
cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack);
|
||||
|
||||
for (auto&& d : data) {
|
||||
StorageType out = 0;
|
||||
StorageType mask = 0x0f;
|
||||
for (int i = 0; i < pack; ++i) {
|
||||
cutlass::int4b_t curr;
|
||||
curr.storage = (d >> (i * 4)) & 0x0f;
|
||||
switch (curr) {
|
||||
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
|
||||
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
|
||||
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
|
||||
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
|
||||
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
|
||||
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
|
||||
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
|
||||
default: break;
|
||||
}
|
||||
out |= (curr.storage << (4 * i)) & mask;
|
||||
mask <<= 4;
|
||||
}
|
||||
d = out;
|
||||
}
|
||||
|
||||
cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ElementScale>
|
||||
bool initialize_packed_scale(
|
||||
cutlass::DeviceAllocation<ElementScale> const& block_in,
|
||||
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8> > & block_out) {
|
||||
|
||||
std::vector<ElementScale> data_in(block_in.size());
|
||||
std::vector<cutlass::Array<ElementScale, 8> > data_out(block_in.size());
|
||||
try {
|
||||
block_in.copy_to_host(data_in.data());
|
||||
} catch (cutlass::cuda_exception const& e)
|
||||
{
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < block_in.size(); ++i)
|
||||
{
|
||||
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
|
||||
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
|
||||
}
|
||||
try {
|
||||
block_out.copy_from_host(data_out.data());
|
||||
} catch (cutlass::cuda_exception const& e)
|
||||
{
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
162
examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp
Normal file
162
examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp
Normal file
@ -0,0 +1,162 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/arch/mma_sm90.hpp"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
// Given a type of MMA instruction, compute a memory reordering atom that places all values
|
||||
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
|
||||
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
|
||||
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
|
||||
// In addition, we can reorder the values across several MMA instructions to get even wider
|
||||
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
|
||||
// more optimal conversion instruction sequences (ValLayout parameter).
|
||||
template<class ElementMma,
|
||||
class AtomLayout = cute::Layout<cute::_1>,
|
||||
class ValLayout = cute::Layout<cute::_1>>
|
||||
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(is_static_v<ValLayout>, "ValLayout must be static");
|
||||
static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
|
||||
|
||||
// 1. Choose an MMA atom to access TV layout and MN shape
|
||||
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
|
||||
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
|
||||
using MmaTraits = MMA_Traits<MmaAtom>;
|
||||
auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
|
||||
auto tv_layout_mma = typename MmaTraits::ALayout{};
|
||||
static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
|
||||
|
||||
// 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
|
||||
// Note: this assumes A is partitioned between warps along M mode
|
||||
auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
|
||||
auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
|
||||
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
|
||||
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
|
||||
|
||||
// 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
|
||||
auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
|
||||
|
||||
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
|
||||
auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
|
||||
auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
|
||||
auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
|
||||
auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
|
||||
|
||||
return layout_atom;
|
||||
}
|
||||
|
||||
template<class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
|
||||
__global__ void reorder_tensor_kernel(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> S,
|
||||
cute::Tensor<EngineDst, LayoutDst> D,
|
||||
TiledCopy tiled_copy)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
using T = typename EngineDst::value_type;
|
||||
|
||||
Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
||||
Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
||||
|
||||
auto thread_copy = tiled_copy.get_slice(threadIdx.x);
|
||||
Tensor tS = thread_copy.partition_S(gS);
|
||||
Tensor tD = thread_copy.partition_D(gD);
|
||||
|
||||
copy(tiled_copy, tS, tD);
|
||||
}
|
||||
|
||||
template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> S,
|
||||
cute::Tensor<EngineDst, LayoutDst> D)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
using T = typename EngineDst::value_type;
|
||||
static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
|
||||
|
||||
// Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
|
||||
// This avoids a race condition when writing out subbyte types (e.g. int4b_t).
|
||||
auto has_major_mode = [](auto s) {
|
||||
return any_of(s, [](auto a){ return is_constant<1, decltype(a)>{}; });
|
||||
};
|
||||
static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
|
||||
"Could not find stride-1 mode in destination layout");
|
||||
constexpr int N = shape_div(Int<8>{}, sizeof_bits<T>{});
|
||||
auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
|
||||
make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
|
||||
make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
|
||||
|
||||
// Make a tiled copy with a simple row-major thread order and above layout
|
||||
int constexpr NumThreads = 128;
|
||||
auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
|
||||
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
|
||||
|
||||
// Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
|
||||
using TileShape = Shape<_16>;
|
||||
auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
|
||||
dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
|
||||
|
||||
reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// In-place version
|
||||
template<class T, class LayoutSrc, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
T const* src,
|
||||
LayoutSrc const& layout_src,
|
||||
T * dst,
|
||||
LayoutDst const& layout_dst)
|
||||
{
|
||||
using namespace cute;
|
||||
reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
|
||||
make_tensor(make_gmem_ptr<T>(dst), layout_dst));
|
||||
}
|
||||
|
||||
// In-place version
|
||||
template<class T, class LayoutSrc, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
T * data,
|
||||
LayoutSrc const& layout_src,
|
||||
LayoutDst const& layout_dst)
|
||||
{
|
||||
using namespace cute;
|
||||
cutlass::DeviceAllocation<T> temp(size(layout_src));
|
||||
reorder_tensor(data, layout_src, temp.get(), layout_dst);
|
||||
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
|
||||
}
|
||||
@ -1,161 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include <cuda.h>
|
||||
#include "helper.h"
|
||||
|
||||
template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleBroadCastLayout,
|
||||
class ThrLayout>
|
||||
__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleBroadCastLayout const broadcasted_scale_layout,
|
||||
ThrLayout thr_layout) {
|
||||
using namespace cute;
|
||||
|
||||
// Represent the full tensors to gmem elements.
|
||||
// These are expected to have shape [MN, K, L]
|
||||
Tensor gmem_op_dq = make_tensor(make_gmem_ptr(dq_buffer), operand_layout);
|
||||
auto init_quantized_iterator = [&]() {
|
||||
if constexpr (cute::sizeof_bits_v<QuantizedElement> >= 8) {
|
||||
return make_gmem_ptr(q_buffer);
|
||||
} else {
|
||||
return subbyte_iterator<const QuantizedElement>(q_buffer);
|
||||
}
|
||||
};
|
||||
Tensor gmem_op_q = make_tensor(init_quantized_iterator(), operand_layout);
|
||||
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
|
||||
// It is expected that K % G == 0
|
||||
Tensor gmem_scale_broadcasted = make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
|
||||
Tensor gmem_zero_broadcasted = make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
|
||||
|
||||
// Assign 1 thread per element in the thread block
|
||||
auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); //
|
||||
auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
|
||||
|
||||
// Tile across the block
|
||||
auto gOp_dq = local_tile(gmem_op_dq, blk_shape, blk_coord);
|
||||
auto gScale = local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
|
||||
auto gZero = local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
|
||||
auto gOp_q = local_tile(gmem_op_q, blk_shape, blk_coord);
|
||||
|
||||
auto tOpDq_gOpDq = local_partition(gOp_dq, thr_layout, threadIdx.x);
|
||||
auto tScale_gScale = local_partition(gScale, thr_layout, threadIdx.x);
|
||||
auto tZero_gZero = local_partition(gZero, thr_layout, threadIdx.x);
|
||||
auto tOpQ_gOpQ = local_partition(gOp_q, thr_layout, threadIdx.x);
|
||||
|
||||
// Make a fragment of registers to hold gmem loads
|
||||
Tensor rmem_op_q = make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
|
||||
Tensor rmem_scale = make_fragment_like(tScale_gScale(_, _, _, 0));
|
||||
Tensor rmem_zero = make_fragment_like(tZero_gZero(_, _, _, 0));
|
||||
Tensor rmem_op_dq = make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
|
||||
Tensor rmem_op_scaled = make_fragment_like<ElementScale>(rmem_op_dq);
|
||||
Tensor rmem_zero_buf = make_fragment_like<ElementScale>(rmem_zero);
|
||||
|
||||
Tensor pred_id = make_identity_tensor(shape(operand_layout));
|
||||
auto pred_blk_tile = local_tile(pred_id, blk_shape, blk_coord);
|
||||
auto pred_thr_partition = local_partition(pred_blk_tile, thr_layout, threadIdx.x);
|
||||
|
||||
const auto num_iters = size<3>(tOpDq_gOpDq);
|
||||
|
||||
for (int ii = 0; ii < num_iters; ++ii) {
|
||||
const auto thread_offset = get<0>(pred_thr_partition(0, 0, 0, ii));
|
||||
if (thread_offset < size<0>(operand_layout)) {
|
||||
copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
|
||||
copy(tScale_gScale(_, _, _, ii), rmem_scale);
|
||||
copy(tZero_gZero(_, _, _, ii), rmem_zero);
|
||||
transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
|
||||
transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
|
||||
transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{});
|
||||
transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{});
|
||||
transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
|
||||
copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleLayout>
|
||||
void dequantize_weight(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleLayout const scale_layout,
|
||||
int const group_size) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
constexpr int tpb = 128;
|
||||
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
|
||||
|
||||
const auto num_rows = get<0>(shape(operand_layout));
|
||||
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
|
||||
|
||||
if (num_rows != size<0>(scale_layout)) {
|
||||
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
|
||||
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const auto scale_stride0 = get<0>(stride(scale_layout));
|
||||
const auto scale_stride1 = get<1>(stride(scale_layout));
|
||||
const auto scale_stride2 = get<2>(stride(scale_layout));
|
||||
|
||||
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
|
||||
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
|
||||
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
|
||||
|
||||
const auto blocks_x = gemm_k;
|
||||
const auto blocks_y = batches;
|
||||
|
||||
dim3 blocks(blocks_x, blocks_y, 1);
|
||||
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
@ -32,7 +32,7 @@
|
||||
/*! \file
|
||||
\brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
|
||||
|
||||
This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA
|
||||
This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA
|
||||
warp-specialized cooperative kernel.
|
||||
The new feature showcased in this example is on-the-fly modification of TMA descriptors
|
||||
to move between batches (represented by l).
|
||||
@ -95,40 +95,66 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
// Different configs for pingpong/cooperative
|
||||
struct CooperativeConfig {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_256,_128,_64>;
|
||||
using ClusterShape = Shape<_1,_2,_1>;
|
||||
};
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
struct PingpongConfig {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = Shape<_64,_128,_64>;
|
||||
using ClusterShape = Shape<_1,_1,_1>;
|
||||
};
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
template <typename ScheduleConfig>
|
||||
struct GemmGivenSchedule {
|
||||
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
|
||||
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
|
||||
|
||||
using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
|
||||
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
@ -261,14 +287,14 @@ bool initialize_block(
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(-2);
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
scope_max = static_cast<Element>(8);
|
||||
scope_min = static_cast<Element>(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
@ -351,7 +377,8 @@ void initialize(const Options &options) {
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
template <typename GemmT>
|
||||
typename GemmT::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
@ -359,7 +386,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
typename GemmT::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kArray,
|
||||
{{options.m, options.n, options.k, options.l}},
|
||||
{ptr_A.get(), stride_A, ptr_B.get(), stride_B},
|
||||
@ -405,20 +432,20 @@ bool verify(const Options &options) {
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
template <typename GemmT>
|
||||
int run(Options &options)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
GemmT gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
auto arguments = args_from_options<GemmT>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = GemmT::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
@ -492,7 +519,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
@ -511,10 +537,14 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
std::cout << "\n*** Cooperative schedule ***" << std::endl;
|
||||
run<Gemm>(options);
|
||||
std::cout << "\n*** Pingpong schedule ***" << std::endl;
|
||||
run<GemmPingpong>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -30,10 +30,10 @@
|
||||
set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes
|
||||
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=1) # Default problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Default problem sizes w/ Epilogue Op test
|
||||
set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test
|
||||
|
||||
set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
/*! \file
|
||||
\brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
|
||||
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
|
||||
warp-specialized cooperative kernel.
|
||||
For this example all scheduling work is performed on the device.
|
||||
The new feature showcased in this example is on-the-fly modification of TMA descriptors
|
||||
@ -42,7 +42,7 @@
|
||||
|
||||
$ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
|
||||
@ -63,7 +63,7 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <float.h>
|
||||
#include <cfloat>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@ -91,9 +91,9 @@
|
||||
|
||||
using namespace cute;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
// Different configs for pingpong/cooperative
|
||||
struct CooperativeConfig {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_256,_128,_128>;
|
||||
using ClusterShape = Shape<_2,_2,_1>;
|
||||
};
|
||||
|
||||
struct PingpongConfig {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = Shape<_128,_128,_128>;
|
||||
using ClusterShape = Shape<_2,_1,_1>;
|
||||
};
|
||||
|
||||
template <typename ScheduleConfig>
|
||||
struct GemmGivenSchedule {
|
||||
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
EpilogueSchedule
|
||||
EpilogueSchedule,
|
||||
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
|
||||
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
|
||||
|
||||
using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
|
||||
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
@ -163,10 +189,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
@ -226,7 +252,7 @@ struct Options {
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@ -271,10 +297,10 @@ struct Options {
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = ((rand() % 512) + 1);
|
||||
m = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = ((rand() % 512) + 1);
|
||||
n = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
@ -438,10 +464,10 @@ void allocate(const Options &options) {
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
|
||||
}
|
||||
|
||||
@ -456,7 +482,7 @@ void allocate(const Options &options) {
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
@ -521,7 +547,8 @@ void initialize(const Options &options) {
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
|
||||
template <typename GemmT>
|
||||
typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
@ -529,33 +556,49 @@ typename Gemm::Arguments args_from_options(const Options &options, bool host_pro
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::EpilogueOutputOp::Params params;
|
||||
typename GemmT::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
params = typename Gemm::EpilogueOutputOp::Params(
|
||||
ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
if (host_problem_shapes_available) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
arguments = typename GemmT::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments = typename Gemm::Arguments {
|
||||
arguments = typename GemmT::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
@ -605,20 +648,20 @@ bool verify(const Options &options) {
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
template <typename GemmT>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
GemmT gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options, host_problem_shapes_available);
|
||||
auto arguments = args_from_options<GemmT>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = GemmT::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
@ -695,7 +738,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
@ -714,8 +756,14 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
std::cout << "\n*** Cooperative schedule ***" << std::endl;
|
||||
run<Gemm>(options);
|
||||
std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl;
|
||||
run<Gemm>(options, false /*host_problem_shapes_available*/);
|
||||
std::cout << "\n*** Pingpong schedule ***" << std::endl;
|
||||
run<GemmPingpong>(options);
|
||||
std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl;
|
||||
run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -32,10 +32,10 @@
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes
|
||||
|
||||
34
examples/58_ada_fp8_gemm/CMakeLists.txt
Normal file
34
examples/58_ada_fp8_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,34 @@
|
||||
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
58_ada_fp8_gemm
|
||||
ada_fp8_gemm.cu
|
||||
)
|
||||
826
examples/58_ada_fp8_gemm/ada_fp8_gemm.cu
Normal file
826
examples/58_ada_fp8_gemm/ada_fp8_gemm.cu
Normal file
@ -0,0 +1,826 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Example of running an Ada FP8 GEMM.
|
||||
|
||||
In addition to using FP8 Tensor Core instructions, the Ada FP8 GEMM uses a distinct epilogue
|
||||
that enables additional scaling of operands/outputs, storing a pre-activation-function output
|
||||
tensor (called the "auxiliary" output), and computing the absolute maximum value of the
|
||||
outputs.
|
||||
|
||||
Pseudocode for this epilogue is as follows:
|
||||
|
||||
Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias
|
||||
D = activation(Aux)
|
||||
|
||||
if Aux is fp8 type:
|
||||
abs_max_output = max( abs(aux) | (for every aux in Aux))
|
||||
Aux = scale_aux * Aux
|
||||
endif
|
||||
|
||||
if D is fp8 type:
|
||||
abs_max_output = max( abs(d) | (for every d in D))
|
||||
D = scale_d * D
|
||||
endif
|
||||
|
||||
Parameter Aux is optionally stored to global memory
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm_complex.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_generic_with_scaling.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_with_absmax.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementOutput = cutlass::float_e4m3_t;
|
||||
using ElementAuxOutput = ElementOutput;
|
||||
using ElementAccumulator = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static int const kStages = 3;
|
||||
static int const kAlignmentA = 16;
|
||||
static int const kAlignmentB = 16;
|
||||
|
||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax<
|
||||
cutlass::epilogue::thread::ReLu,
|
||||
ElementOutput,
|
||||
ElementAuxOutput,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>;
|
||||
|
||||
template <typename MathOperator>
|
||||
using Gemm_ = cutlass::gemm::device::GemmUniversalWithAbsMax<
|
||||
ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89,
|
||||
cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
|
||||
EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages,
|
||||
kAlignmentA, kAlignmentB, MathOperator
|
||||
>;
|
||||
|
||||
using ElementAbsmax = typename EpilogueOutputOp::ElementAbsmax;
|
||||
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
bool reference_check;
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
|
||||
int iterations;
|
||||
int warmup_iterations;
|
||||
|
||||
bool scale_A;
|
||||
bool scale_B;
|
||||
bool scale_C;
|
||||
|
||||
float alpha;
|
||||
float beta;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
reference_check(false),
|
||||
iterations(20),
|
||||
warmup_iterations(5),
|
||||
scale_A(true),
|
||||
scale_B(true),
|
||||
scale_C(true),
|
||||
alpha(1.f),
|
||||
beta(0.f)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, 20);
|
||||
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, 5);
|
||||
cmd.get_cmd_line_argument("reference-check", reference_check, false);
|
||||
cmd.get_cmd_line_argument("scale-A", scale_A, true);
|
||||
cmd.get_cmd_line_argument("scale-B", scale_B, true);
|
||||
cmd.get_cmd_line_argument("scale-C", scale_C, true);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
|
||||
int m, n, k;
|
||||
cmd.get_cmd_line_argument("m", m, 1024);
|
||||
cmd.get_cmd_line_argument("n", n, 1024);
|
||||
cmd.get_cmd_line_argument("k", k, 1024);
|
||||
|
||||
problem_size = cutlass::gemm::GemmCoord{m, n, k};
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "58_ada_fp8_gemm\n\n"
|
||||
<< " This example executes a GEMM using Ada FP8 Tensor Core operations. In addition to performing\n"
|
||||
<< " a normal GEMM, the kernel performs the following operations:\n"
|
||||
<< " Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias\n"
|
||||
<< " D = activation(Aux)\n\n"
|
||||
<< " if Aux is fp8:\n"
|
||||
<< " abs_max_output = max( abs(aux) | (for every aux in Aux) )\n"
|
||||
<< " Aux = scale_aux * Aux\n\n"
|
||||
<< " if D is fp8 type:\n"
|
||||
<< " abs_max_output = max( abs(d) | (for every d in D) )\n"
|
||||
<< " D = scale_d * D\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M dimension of the GEMM\n"
|
||||
<< " --n=<int> Sets the N dimension of the GEMM\n"
|
||||
<< " --k=<int> Sets the K dimension of the GEMM\n"
|
||||
<< " --scale-A=<bool> Whether to apply a scaling factor to operand A (default: true)\n"
|
||||
<< " --scale-B=<bool> Whether to apply a scaling factor to operand B (default: true)\n"
|
||||
<< " --scale-C=<bool> Whether to apply a scaling factor to operand C (default: true)\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n"
|
||||
<< " --warmup-iterations=<int> Number of warmup iterations to perform\n"
|
||||
<< " --reference-check=<bool> If true, performs reference check\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
float gflops(float runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
return 2.0f * float(problem_size.product()) / float(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper class to run the kernel
|
||||
template <typename Gemm>
|
||||
struct TestbedRunner {
|
||||
|
||||
using ElementAccumulator = typename Gemm::ElementAccumulator;
|
||||
using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute;
|
||||
using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor;
|
||||
|
||||
static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded;
|
||||
static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
|
||||
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
|
||||
cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_C;
|
||||
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementAuxOutput, typename Gemm::LayoutC> tensor_Aux;
|
||||
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, typename Gemm::LayoutC> tensor_D;
|
||||
cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_Vector;
|
||||
cutlass::HostTensor<ElementAccumulator, typename Gemm::LayoutC> tmp_D;
|
||||
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, typename Gemm::LayoutC> reference_D;
|
||||
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementAuxOutput, typename Gemm::LayoutC> reference_Aux;
|
||||
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_A;
|
||||
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_B;
|
||||
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_C;
|
||||
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_D;
|
||||
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_Aux;
|
||||
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> abs_max_Aux;
|
||||
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> abs_max_D;
|
||||
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> reference_abs_max_Aux;
|
||||
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> reference_abs_max_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
TestbedRunner(
|
||||
bool scaleA = true,
|
||||
bool scaleB = true,
|
||||
bool scaleC = true,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize scaling factors
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_factor(cutlass::TensorView<Element, Layout> view, uint64_t seed, int bits=0) {
|
||||
cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
std::cerr << "Not implemented";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initializes data structures
|
||||
void initialize(const Options& options) {
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
tensor_A.resize(options.problem_size.mk());
|
||||
tensor_B.resize(options.problem_size.kn());
|
||||
tensor_C.resize(options.problem_size.mn());
|
||||
tensor_D.resize(options.problem_size.mn());
|
||||
tensor_Vector.resize({1, options.problem_size.n()});
|
||||
reference_D.resize(options.problem_size.mn(), false);
|
||||
tmp_D.resize(options.problem_size.mn(), false);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), init_A, seed + 2019);
|
||||
initialize_tensor(tensor_B.host_view(), init_B, seed + 2018);
|
||||
initialize_tensor(tensor_C.host_view(), init_C, seed + 2017);
|
||||
initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020);
|
||||
|
||||
// It is possible to randomly initialize to all zeros, so override this with non-zeros
|
||||
// in the upper left corner of each operand.
|
||||
cutlass::Coord<2> origin(0);
|
||||
tensor_A.host_view().at(origin) = typename Gemm::ElementA(1);
|
||||
tensor_B.host_view().at(origin) = typename Gemm::ElementB(1);
|
||||
tensor_C.host_view().at(origin) = typename Gemm::ElementC(1);
|
||||
tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1);
|
||||
|
||||
cutlass::reference::host::TensorFill(tensor_D.host_view());
|
||||
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
tensor_Vector.sync_device();
|
||||
|
||||
int scale_bits = 2;
|
||||
if (options.scale_A) {
|
||||
scale_A.resize({1, 1});
|
||||
initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits);
|
||||
scale_A.sync_device();
|
||||
}
|
||||
|
||||
if (options.scale_B) {
|
||||
scale_B.resize({1, 1});
|
||||
initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits);
|
||||
scale_B.sync_device();
|
||||
}
|
||||
|
||||
if (options.scale_C) {
|
||||
scale_C.resize({1, 1});
|
||||
initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits);
|
||||
scale_C.sync_device();
|
||||
}
|
||||
|
||||
if (kScaleOutput) {
|
||||
scale_D.resize({1, 1});
|
||||
initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits);
|
||||
scale_D.sync_device();
|
||||
|
||||
abs_max_D.resize({1, 1});
|
||||
cutlass::reference::host::TensorFill(abs_max_D.host_view());
|
||||
abs_max_D.sync_device();
|
||||
|
||||
reference_abs_max_D.resize({1, 1});
|
||||
}
|
||||
|
||||
if (kScaleAux) {
|
||||
tensor_Aux.resize(options.problem_size.mn());
|
||||
cutlass::reference::host::TensorFill(tensor_Aux.host_view());
|
||||
tensor_Aux.sync_device();
|
||||
|
||||
scale_Aux.resize({1, 1});
|
||||
initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits);
|
||||
scale_Aux.sync_device();
|
||||
|
||||
abs_max_Aux.resize({1, 1});
|
||||
cutlass::reference::host::TensorFill(abs_max_Aux.host_view());
|
||||
abs_max_Aux.sync_device();
|
||||
|
||||
reference_Aux.resize(options.problem_size.mn(), false);
|
||||
reference_abs_max_Aux.resize({1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
/// Compares computed reference with device reference and outputs to a file if incorrect
|
||||
bool compare_reference(const Options& options) {
|
||||
|
||||
tensor_D.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
|
||||
|
||||
if (kScaleAux) {
|
||||
tensor_Aux.sync_host();
|
||||
abs_max_Aux.sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view());
|
||||
passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view());
|
||||
}
|
||||
|
||||
if (kScaleOutput) {
|
||||
abs_max_D.sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view());
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
|
||||
std::string output_file = "testbed_with_amax_errors.txt";
|
||||
std::ofstream file(output_file);
|
||||
|
||||
file
|
||||
<< "problem: " << options.problem_size
|
||||
<< ", alpha: " << options.alpha << ", beta: " << options.beta << "\n\n";
|
||||
|
||||
file
|
||||
<< "A =\n" << tensor_A.host_view()
|
||||
<< "\nB =\n" << tensor_B.host_view()
|
||||
<< "\nC =\n" << tensor_C.host_view()
|
||||
<< "\nVector =\n" << tensor_Vector.host_view()
|
||||
<< "\nScaleA = " << scale_A.host_view()
|
||||
<< "\nScaleB = " << scale_B.host_view()
|
||||
<< "\nScaleC = " << scale_C.host_view()
|
||||
<< "\nScaleD = " << scale_D.host_view()
|
||||
<< "\nScaleAux = " << scale_Aux.host_view()
|
||||
<< "\n\nReference D =\n" << reference_D.host_view()
|
||||
<< "\nComputed D =\n" << tensor_D.host_view();
|
||||
if (kScaleAux) {
|
||||
file
|
||||
<< "\n\nReference Aux =\n" << reference_Aux.host_view()
|
||||
<< "\nComputed Aux =\n" << tensor_Aux.host_view()
|
||||
<< "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view()
|
||||
<< "\nComputed Absmax Aux = " << abs_max_Aux.host_view();
|
||||
}
|
||||
if (kScaleOutput) {
|
||||
file
|
||||
<< "\n\nReference Absmax D = " << reference_abs_max_D.host_view()
|
||||
<< "\nComputed Absmax D = " << abs_max_D.host_view();
|
||||
}
|
||||
|
||||
std::cerr << "Dumped results to " << output_file << std::endl;
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Verifies the result is a GEMM
|
||||
bool verify(const Options& options) {
|
||||
|
||||
cutlass::Coord<2> origin(0);
|
||||
ElementCompute scaled_alpha = options.alpha;
|
||||
if (options.scale_A) {
|
||||
scaled_alpha *= scale_A.host_view().at(origin);
|
||||
}
|
||||
if (options.scale_B) {
|
||||
scaled_alpha *= scale_B.host_view().at(origin);
|
||||
}
|
||||
|
||||
ElementCompute scaled_beta = options.beta;
|
||||
if (options.scale_C) {
|
||||
scaled_beta *= scale_C.host_view().at(origin);
|
||||
}
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::host::GemmComplex<
|
||||
typename Gemm::ElementA, typename Gemm::LayoutA,
|
||||
typename Gemm::ElementB, typename Gemm::LayoutB,
|
||||
typename Gemm::ElementC, typename Gemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator, ElementAccumulator
|
||||
>(
|
||||
options.problem_size,
|
||||
scaled_alpha,
|
||||
tensor_A.host_ref(),
|
||||
Gemm::kTransformA,
|
||||
tensor_B.host_ref(),
|
||||
Gemm::kTransformB,
|
||||
scaled_beta,
|
||||
tensor_C.host_ref(),
|
||||
tmp_D.host_ref(),
|
||||
ElementAccumulator(0)
|
||||
);
|
||||
|
||||
ElementCompute tmp_abs_max_Aux(0.);
|
||||
ElementCompute tmp_abs_max_D(0.);
|
||||
|
||||
cutlass::NumericConverter<ElementCompute, typename Gemm::ElementC> cvt_c_to_compute;
|
||||
cutlass::NumericConverter<ElementCompute, ElementAccumulator> cvt_accum_to_compute;
|
||||
cutlass::NumericConverter<ElementAccumulator, ElementCompute> cvt_compute_to_accum;
|
||||
cutlass::NumericConverter<typename Gemm::EpilogueOutputOp::ElementOutput, ElementCompute> cvt_compute_to_d;
|
||||
cutlass::NumericConverter<typename Gemm::EpilogueOutputOp::ElementAuxOutput, ElementCompute> cvt_compute_to_aux;
|
||||
|
||||
cutlass::absolute_value_op<ElementCompute> abs;
|
||||
cutlass::maximum_with_nan_propogation<ElementCompute> max;
|
||||
cutlass::epilogue::thread::ReLu<ElementCompute> act;
|
||||
|
||||
ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.);
|
||||
|
||||
for (int m = 0; m < options.problem_size.m(); ++m) {
|
||||
for (int n = 0; n < options.problem_size.n(); ++n) {
|
||||
ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n}));
|
||||
ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n}));
|
||||
ElementCompute aux = intermediate + bias;
|
||||
ElementCompute d = act(aux);
|
||||
tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux);
|
||||
tmp_abs_max_D = max(abs(d), tmp_abs_max_D);
|
||||
reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale);
|
||||
|
||||
if (kScaleAux) {
|
||||
reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (kScaleAux) {
|
||||
reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_Aux);
|
||||
}
|
||||
|
||||
if (kScaleOutput) {
|
||||
reference_abs_max_D.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_D);
|
||||
}
|
||||
|
||||
return compare_reference(options);
|
||||
}
|
||||
|
||||
/// Returns true if the CUDA device is sufficient to execute the kernel.
|
||||
bool sufficient() const {
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
|
||||
std::cerr << "This example requires CUDA 12.4 or greater." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage);
|
||||
|
||||
cudaDeviceProp properties;
|
||||
int device_idx;
|
||||
cudaError_t result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaGetDevice() failed with error: " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() failed with error: " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (properties.major < 8 || (properties.major == 8 && properties.minor < 9)) {
|
||||
std::cerr << "CUTLASS's Ada FP8 GEMM example requires a device of compute capability 89 or higher.\n" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (properties.sharedMemPerBlockOptin < smem_size) {
|
||||
std::cerr << "Insufficient shared memory. Need " << smem_size
|
||||
<< ", but device only has " << properties.sharedMemPerBlockOptin << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Executes one test
|
||||
bool run(Options& options)
|
||||
{
|
||||
|
||||
// Waive test if insufficient CUDA device
|
||||
if (!sufficient()) {
|
||||
std::cerr << "Insufficient resources to run the kernel." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
this->initialize(options);
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{
|
||||
ElementCompute(options.alpha),
|
||||
ElementCompute(options.beta)
|
||||
};
|
||||
typename Gemm::EpilogueOutputOp::Params epilogue_params{
|
||||
activation_params,
|
||||
scale_A.device_data(),
|
||||
scale_B.device_data(),
|
||||
scale_C.device_data(),
|
||||
scale_D.device_data(),
|
||||
scale_Aux.device_data(),
|
||||
abs_max_Aux.device_data(),
|
||||
abs_max_D.device_data()
|
||||
};
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
options.problem_size,
|
||||
/* batch_count = */ 1,
|
||||
epilogue_params,
|
||||
tensor_A.device_data(),
|
||||
tensor_B.device_data(),
|
||||
tensor_C.device_data(),
|
||||
tensor_D.device_data(),
|
||||
tensor_Aux.device_data(),
|
||||
tensor_Vector.device_data(),
|
||||
options.problem_size.m() * options.problem_size.k(),
|
||||
options.problem_size.n() * options.problem_size.k(),
|
||||
options.problem_size.m() * options.problem_size.n(),
|
||||
options.problem_size.m() * options.problem_size.n(),
|
||||
(int)options.problem_size.m(), // Batch stride vector
|
||||
tensor_A.layout().stride(0),
|
||||
tensor_B.layout().stride(0),
|
||||
tensor_C.layout().stride(0),
|
||||
tensor_D.layout().stride(0),
|
||||
(int64_t)0 // Leading dimension of vector. This must be 0
|
||||
};
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Gemm::can_implement() failed" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Gemm::initialize() failed" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
status = gemm_op();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Gemm::run() failed" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t cuda_error = cudaDeviceSynchronize();
|
||||
if (cuda_error != cudaSuccess) {
|
||||
std::cerr << "CUDA error: " << cudaGetErrorString(cuda_error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
bool passed = true;
|
||||
if (options.reference_check) {
|
||||
passed &= this->verify(options);
|
||||
} else {
|
||||
std::cout << "Skipped reference check" << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Warm up
|
||||
//
|
||||
|
||||
for (int i = 0; i < options.warmup_iterations; ++i) {
|
||||
gemm_op();
|
||||
}
|
||||
|
||||
//
|
||||
// Profile
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
cudaError_t error;
|
||||
for (auto & event : events) {
|
||||
error = cudaEventCreate(&event);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMM operations
|
||||
error = cudaEventRecord(events[0]);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
gemm_op();
|
||||
}
|
||||
|
||||
// Record an event when the GEMM operations have been launched.
|
||||
error = cudaEventRecord(events[1]);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
error = cudaEventSynchronize(events[1]);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
runtime_ms = runtime_ms / float(options.iterations);
|
||||
float gflops = options.gflops(runtime_ms / 1000.0f);
|
||||
|
||||
std::cout << "Problem size: " << options.problem_size.m() << 'x' << options.problem_size.n() << 'x' << options.problem_size.k() << std::endl;
|
||||
std::cout << "Runtime (ms): " << runtime_ms << std::endl;
|
||||
std::cout << "GFLOPs/sec: " << gflops << std::endl;
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const** argv) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4) ||
|
||||
(props.major != 8 && props.minor != 9)) {
|
||||
|
||||
//
|
||||
// This example requires an NVIDIA Ada-architecture GPU.
|
||||
//
|
||||
|
||||
std::cout
|
||||
<< "CUTLASS's FP8 SM89 example requires a GPU of NVIDIA's Ada architecture "
|
||||
<< "and CUDA toolkit version 12.4 or later.\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::cout << "Running GEMM with staged accumulation (OpMultiplyAdd)" << std::endl;
|
||||
std::cout << "=====================================================" << std::endl;
|
||||
TestbedRunner<Gemm_<cutlass::arch::OpMultiplyAdd>> testbed_staged_accum;
|
||||
bool passed = testbed_staged_accum.run(options);
|
||||
|
||||
if (passed) {
|
||||
std::cout << "Passed" << std::endl;
|
||||
} else {
|
||||
std::cout << "Failed" << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "\nRunning GEMM with fast accumulation (OpMultiplyAddFastAccum)" << std::endl;
|
||||
std::cout << "============================================================" << std::endl;
|
||||
TestbedRunner<Gemm_<cutlass::arch::OpMultiplyAddFastAccum>> testbed_fast_accum;
|
||||
passed = testbed_fast_accum.run(options);
|
||||
|
||||
if (passed) {
|
||||
std::cout << "Passed" << std::endl;
|
||||
} else {
|
||||
std::cout << "Failed" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
40
examples/59_ampere_gather_scatter_conv/CMakeLists.txt
Normal file
40
examples/59_ampere_gather_scatter_conv/CMakeLists.txt
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (NOT MSVC)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
59_ampere_gather_scatter_conv
|
||||
ampere_gather_scatter_conv.cu
|
||||
)
|
||||
|
||||
if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND)
|
||||
target_link_libraries(59_ampere_gather_scatter_conv PRIVATE OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
209
examples/59_ampere_gather_scatter_conv/README.md
Normal file
209
examples/59_ampere_gather_scatter_conv/README.md
Normal file
@ -0,0 +1,209 @@
|
||||
# Example 59: Ampere gather/scatter convolution
|
||||
|
||||
CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel capable of operating on both affine and gather/scatter tensors.
|
||||
|
||||
Example executions:
|
||||
```sh
|
||||
./59_ampere_gather_scatter_conv
|
||||
./59_ampere_gather_scatter_conv --n=108
|
||||
./59_ampere_gather_scatter_conv --n=4096 --i=1
|
||||
./59_ampere_gather_scatter_conv --n=1080 --i=1000
|
||||
./59_ampere_gather_scatter_conv --n=131072 --i=1000 --no-check
|
||||
```
|
||||
|
||||
This example demonstrates a few super cool features of CUTLASS and CuTe. It shows off
|
||||
1. A dense conv 3D fprop kernel written as a single file ...
|
||||
2. ... that leverages off-the-shelf CUTLASS collectives to show how custom kernels can use collectives ...
|
||||
3. ... and uses the exact same templated kernel to also stamp out a gather/scatter 3D fprop conv ...
|
||||
4. ... while getting near peak performance of the Ampere class tensor core on Ampere and Ada GPUs ...
|
||||
5. ... by using static cute shapes and strides in case problem shapes are known at compile time.
|
||||
|
||||
## A dense conv 3D fprop kernel written in CUTLASS 3.x and CuTe
|
||||
|
||||
The most common strategy for implementing high performance convolution kernels on the GPU is to transform
|
||||
the activation tensor in such a way that we can perform the computation as a GEMM. This is called the
|
||||
image to column (im2col) transformation. [CUTLASS 2.x implementation of im2col based convolutions is
|
||||
documented separately](../../media/docs/implicit_gemm_convolution.md), and here we consider a fresh approach for CuTe.
|
||||
|
||||
A 3D convolution has the following input tensors:
|
||||
- Activation tensor (Act): `((N,(D,H,W)), (C,(1,1,1)))`
|
||||
- Filter tensor (Flt): `( K, (C,(T,R,S)))`
|
||||
- Output tensor (Out): `((N,(Z,P,Q)), K )`
|
||||
|
||||
Where
|
||||
- N := number of images
|
||||
- DHW := spatial dimensions of the activation tensor
|
||||
- C := channel dimension of the activation tensor
|
||||
- K := channel dimension of the filter and output tensor
|
||||
- TRS := spoke dimensions of the filter tensor
|
||||
- ZPQ := spatial dimensions of the output tensor
|
||||
|
||||
As is evident in the tensor shapes, these cannot be issued to a GEMM just yet, since there is no
|
||||
logical M, N, and K modes we can group the tensor modes into.
|
||||
|
||||
Notice that every spoke of the filter tensor (TRS) will be applied to some (offset) view of the
|
||||
activation tensor, thus expanding the logical size of the activation tensor.
|
||||
Additionally, a similar logical transform of the spatial dimensions can be encoded as a function of the
|
||||
padding, dilations, traversal strides, and filter spokes. This gets us to our im2col transform:
|
||||
|
||||
im2col transform affects the component shapes/strides of the activation tensor in the following way:
|
||||
- ZPQ Shape : changes DHW domain with formula `(1 + (DHW + pad - (((TRS-1) * dilation) + 1)) / traversal_stride)`
|
||||
- TRS Shape : TRS domain instead of `(1,1,1)`
|
||||
- ZPQ Strides : Original DHW strides get `elem_scale()`-ed by traversal strides DHW
|
||||
- TRS Strides : Original DHW strides get `elem_scale()`-ed by dilation DHW
|
||||
|
||||
With this transform applied, we end up with a set of input and output tensors that
|
||||
are logically consistent in their MNK dimensions, thus allowing us to dispatch to a GEMM.
|
||||
im2col activation layout: ((N,(Z,P,Q)), (C,(T,R,S))) // logical (M,K)
|
||||
filter layout : ( K, (C,(T,R,S))) // logical (N,K)
|
||||
output layout : ((N,(Z,P,Q)), K ) // logical (M,N)
|
||||
|
||||
CuTe's layout representation and algebra make these folded tensors easy to represent and manipulate.
|
||||
This is most evident in the reference check code used in this example:
|
||||
|
||||
```cpp
|
||||
for (size_t logical_m = 0; logical_m < size<0>(mOutputRef); ++logical_m) {
|
||||
for (size_t logical_n = 0; logical_n < size<1>(mOutputRef); ++logical_n) {
|
||||
auto accumulator = float(0);
|
||||
for (size_t logical_k = 0; logical_k < size<1>(mStencil); ++logical_k) {
|
||||
accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
|
||||
}
|
||||
mOutputRef(logical_m, logical_n) = accumulator;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Which succinctly demonstrates how im2col transform allows us to implement convolutions
|
||||
as GEMMs with special layout transformations on the input tensor.
|
||||
|
||||
Note: in the example kernel's implementation we treat activations as the B tensor
|
||||
and filter as the A tensor, thus making their logical dimensions NK and MK respectively.
|
||||
|
||||
## Leveraging CUTLASS collectives off the shelf in a custom kernel
|
||||
|
||||
Now that we have transformed our problem in such a way that allows us to dispatch to a GEMM,
|
||||
we can reuse much of the machinery CUTLASS offers to implement this forward pass convolution
|
||||
operator. CUTLASS decomposes these "moving parts" of GPU linear algebra into reusable,
|
||||
modular software components abstracted by C++ template classes. This example
|
||||
demonstrates how some of the lower layers of the hierarchy can be re-used for custom kernels
|
||||
by writing a custom kernel for convolution that re-uses the Ampere/Ada GEMM collectives
|
||||
from CUTLASS 3.
|
||||
|
||||
A kernel author is free to compose their custom components with any of the existing templates
|
||||
in the CUTLASS hierarchy to leverage existing high performance implementations from the CUTLASS
|
||||
team. In this example, we write a custom kernel layer and compose with an existing collective.
|
||||
However, any of the CUTLASS kernels can be composed with bespoke collectives if the desired
|
||||
customization is a mainloop or epilogue fusion without changes to the grid planning,
|
||||
tile scheduling, load balancing, or thread marshalling.
|
||||
|
||||
## Implementing gather/scatter and dense convolution with the same kernel
|
||||
|
||||
Functionality and correctness of the implemented kernel, as a virtue of using
|
||||
CuTe and off the shelf CUTLASS collectives, only relies on the logical consistency of
|
||||
the layouts of input and output tensors. This means that we can freely change how
|
||||
the logical coordinates of the tensors map into the index space, and even how these dereferences
|
||||
happen. [CUTLASS example 52](../52_hopper_gather_scatter_fusion/) demonstrates this by implementing a custom stride that
|
||||
supports indexed indirection for tensor data accesses. This allows for example 52
|
||||
to implement a GEMM where inputs are gathered and output is scattered based on an index buffer.
|
||||
|
||||
We re-use the same custom stride utilities in this example to implement a convolution kernel
|
||||
that gathers along the NDHW dimensions of the activation tensor and scatters the output along the
|
||||
NZPQ dimensions of the output tensor, treating the channel dimensions as the dense vectors.
|
||||
|
||||
Our dense affine im2col transformed activation tensor:
|
||||
|
||||
```cpp
|
||||
// im2col transformed activation layout: ((nzpq), (ctrs)) => idx
|
||||
auto xformed_act_layout = make_layout(
|
||||
make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)),
|
||||
make_stride(make_stride(D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, H*W*C, W*C, C)));
|
||||
```
|
||||
|
||||
now becomes a composed layout that uses `IndexedGather`:
|
||||
|
||||
```cpp
|
||||
// Inner layout of the composition:
|
||||
// ((nzpq), (csrt)) => (idx_buffer_idx, dense_offset)
|
||||
auto EG = E<0>{}; // Gather basis (1,0) (idx_buffer_idx)
|
||||
auto EC = E<1>{}; // Contiguous basis (0,1) (dense_offset)
|
||||
auto xformed_act_logical_inner = make_layout(
|
||||
make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)),
|
||||
make_stride(make_stride(D*H*W*EG, H*W*EG, W*EG, EG), make_stride(EC, H*W*EG, W*EG, EG)));
|
||||
|
||||
// Outer layout of the composition:
|
||||
// (idx_buffer_idx, dense_offset) => idx
|
||||
// IndexedGather obtains idx by applying (gmem_base_ptr + gather_idx_buf[idx_buffer_idx] + dense_offset)
|
||||
auto xformed_act_gather_outer = make_layout(
|
||||
make_shape(_1{},_1{}),
|
||||
make_stride(CustomStride{IndexedGather{gather_idx_buf}, C}, _1{}));
|
||||
|
||||
// Compose the inner and outer layouts
|
||||
// ((nzpq), (ctrs)) => idx
|
||||
auto xformed_act_composed_layout = composition(
|
||||
xformed_act_gather_outer,
|
||||
make_arithmetic_tuple(_0{}, _0{}),
|
||||
xformed_act_logical_inner);
|
||||
```
|
||||
|
||||
Here, we create a composed layout whose inner layout has the same logical MK shape as earlier,
|
||||
but with an outer layout that uses the custom strides with an index buffer to access memory with
|
||||
indirections. A custom stride requires two inputs to compute the index that a certain coordinate maps to:
|
||||
the index buffer offset and the dense offset into the vector. This entails that our inner layout
|
||||
(the one with the logical MK shape) has a rank-2 codomain `(idx_buffer_idx, dense_offset)`.
|
||||
We can set up such a layout with scaled basis strides, which allow us to map a domain onto a
|
||||
codomain with multiple orthogonal bases. The two codomain basis are the
|
||||
index buffer offsets (rank 0 basis), and the dense vector offsets (rank 1 basis).
|
||||
A similar composed layout is set up for the output scatter tensor.
|
||||
|
||||
This tensor still has a logical MK shape and is backed by a CuTe layout, which means we can still
|
||||
tile, partition, and otherwise manipulate it with CuTe's layout algebra in the same way we would any
|
||||
other tensor. Substituting the activation tensor's affine layout for this gather layout requires
|
||||
no changes to the implementation of the kernel whatsoever. Everything composes. This example
|
||||
stamps out a dense 3D convolution as well as gather/scatter 3D convolution using the same kernel template,
|
||||
with the only difference between them being the layouts of the input and output tensors.
|
||||
|
||||
Convolutions are just a special case of tensor contractions, and as [example 51](../51_hopper_gett)
|
||||
demonstrates, the exact same collective used in this example can also be used to implement arbitrary GETTs.
|
||||
Of course, this also means that the same kernel can implement gather/scatter GETTs as well!
|
||||
|
||||
This demonstrates the composition power of not just CuTe, but also CUTLASS 3's two level
|
||||
micro kernel abstraction. A single highly tuned temporal micro-kernel (collective) can be implemented once
|
||||
and applied to compute dense GETTs, gather/scatter GETTs, dense convolutions, and gather/scatter convolutions.
|
||||
|
||||
## Peak performance on Ampere and Ada GPUs by leveraging domain specific knowledge
|
||||
|
||||
Often, when implementing custom kernels, a user has more knowledge of the problem domain that can be
|
||||
exploited to deliver higher performance than otherwise could be through general kernels. In this example
|
||||
we presume that the shape of each of the images (DHWC dimensions) as well as the filter (TRS) are available
|
||||
a-priori and that the tile shape evenly divides the problem. Number of images (N) is still left as a runtime
|
||||
parameter.
|
||||
|
||||
Knowing the extents of our tensors at compile time allows us to encode them as static cute shapes rather than
|
||||
a dynamic problem shape, resulting in the elimination of most of the index computation instructions such as
|
||||
expensive div/mods. Knowing that the problem shape is divisible by the tile shape allows us to use the
|
||||
Ampere collective that does not perform predication on global memory loads, further reducing overheads
|
||||
and allowing us to achieve near peak performance on RTX Ampere and Ada GPUs.
|
||||
|
||||
Running this example on an RTX 3080Ti prints the following performance numbers (some output culled for brevity):
|
||||
|
||||
```
|
||||
$> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check
|
||||
Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.
|
||||
|
||||
Allocating tensors ... done.
|
||||
Initializing data ... done.
|
||||
Initializing gather/scatter index buffers ... done.
|
||||
|
||||
Running dense fprop kernel
|
||||
Conv TFLOP count = 0.927713
|
||||
Conv dense perf: 31.027376ms | TFLOP/s = 29.899819
|
||||
|
||||
Running gather/scatter fprop kernel
|
||||
Conv TFLOP count = 0.927713
|
||||
Conv gather/scatter perf: 28.973721ms | TFLOP/s = 32.019117
|
||||
```
|
||||
|
||||
With this in mind, this example kernel has the following limitations:
|
||||
- This example kernel only supports dynamic image count, all other conv problem shape must be defined as `cute::Constant<>`s
|
||||
- Problem shapes (including dynamic image count `N`) must be evenly divisible by the tile shape
|
||||
- It does not perform fp32->tf32 numeric conversion, gmem inputs must be rounded to tf32 already
|
||||
320
examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h
Normal file
320
examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h
Normal file
@ -0,0 +1,320 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/atom/copy_atom.hpp"
|
||||
#include <random>
|
||||
|
||||
#include "cutlass/util/print_error.hpp"
|
||||
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_mma.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct AmpereUnpredicatedFprop {
|
||||
//
|
||||
// Static config for conv problem shape
|
||||
//
|
||||
using D = _6;
|
||||
using H = _4;
|
||||
using W = _4;
|
||||
|
||||
using T = _3;
|
||||
using R = _3;
|
||||
using S = _3;
|
||||
|
||||
using Z = _4;
|
||||
using P = _2;
|
||||
using Q = _2;
|
||||
|
||||
using C = _64;
|
||||
using K = _128;
|
||||
|
||||
// Tiler config
|
||||
using Tiler_K = decltype(cute::min(K{}, _128{}));
|
||||
using Tiler_C = decltype(cute::min(C{}, _32{}));
|
||||
using Tiler_N = _4;
|
||||
using TileM = Tiler_K;
|
||||
using TileN = Shape<Tiler_N, Z, P, Q>;
|
||||
using TileK = Shape<Tiler_C,_1,_1,_1>;
|
||||
using PIPE = _3;
|
||||
using TilerFlt = Shape<TileM, TileK>;
|
||||
using TilerAct = Shape<TileN, TileK>;
|
||||
using TilerOut = Shape<TileM, TileN>;
|
||||
|
||||
using TileSizeM = Int<size(TileM{})>;
|
||||
using TileSizeN = Int<size(TileN{})>;
|
||||
using TileSizeK = Int<size(TileK{})>;
|
||||
static constexpr int Stages = PIPE::value;
|
||||
|
||||
using ElementFlt = tfloat32_t;
|
||||
using ElementAct = tfloat32_t;
|
||||
using ElementOut = float;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
|
||||
Layout<Shape<_2,_2,_1>>,
|
||||
Tile<_32,_32,Underscore>>;
|
||||
|
||||
static constexpr int MaxThreadsPerBlock = size(TiledMma{});
|
||||
static constexpr int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
union SharedStorage {
|
||||
struct {
|
||||
ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})];
|
||||
ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})];
|
||||
} mainloop;
|
||||
|
||||
struct {
|
||||
ElementOut sCMatrix[size(TileM{}) * size(TileN{})];
|
||||
} epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Stencil tensor
|
||||
//
|
||||
|
||||
using GmemLayoutFlt = decltype(make_ordered_layout(
|
||||
Shape< K, Shape< C, T, R, S>>{},
|
||||
tuple<_4, tuple<_0,_3,_2,_1>>{}));
|
||||
|
||||
// We have 64 elements * 32b each in the major mode that we can vectorize
|
||||
// Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4
|
||||
// Rest along the minor mode
|
||||
using GmemTiledCopyFlt = decltype(make_tiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementFlt>{},
|
||||
Layout<Shape <_16, _8>,
|
||||
Stride< _8, _1>>{},
|
||||
Layout<Shape < _1, _4>>{}));
|
||||
|
||||
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
|
||||
// using SmemLayoutFlt = decltype(
|
||||
// composition(Swizzle<3,2,3>{},
|
||||
// make_ordered_layout(
|
||||
// Shape<TileSizeM,TileSizeK,PIPE>{},
|
||||
// tuple< _1, _0, _2>{})));
|
||||
|
||||
using SmemLayoutAtomFlt = decltype(
|
||||
composition(Swizzle<1,2,3>{},
|
||||
Layout<Shape <_8,Shape <_4, _2>>,
|
||||
Stride<_4,Stride<_1,_32>>>{}));
|
||||
|
||||
using SmemCopyAtomFlt = Copy_Atom<SM75_U32x4_LDSM_N, ElementFlt>;
|
||||
|
||||
//
|
||||
// Activation tensor
|
||||
//
|
||||
|
||||
// Activation tensor is major in the contraction mode, so vectorize that mode first
|
||||
// Then lay out the rest of the threads along the other mode
|
||||
using GmemTiledCopyAct = decltype(make_tiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementAct>{},
|
||||
Layout<Shape <_16, _8>,
|
||||
Stride< _8, _1>>{},
|
||||
Layout<Shape < _1, _4>>{}));
|
||||
|
||||
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
|
||||
// using SmemLayoutAct = decltype(
|
||||
// composition(Swizzle<3,2,3>{},
|
||||
// make_ordered_layout(
|
||||
// Shape<TileSizeN,TileSizeK,PIPE>{},
|
||||
// tuple< _1, _0, _2>{})));
|
||||
|
||||
using SmemLayoutAtomAct = decltype(
|
||||
composition(Swizzle<1,2,3>{},
|
||||
Layout<Shape <_8,Shape <_4, _2>>,
|
||||
Stride<_4,Stride<_1,_32>>>{}));
|
||||
|
||||
using SmemCopyAtomAct = Copy_Atom<SM75_U32x4_LDSM_N, ElementAct>;
|
||||
|
||||
//
|
||||
// Output tensor
|
||||
//
|
||||
|
||||
using GmemTiledCopyOut = decltype(make_tiled_copy(
|
||||
Copy_Atom<UniversalCopy<uint128_t>, ElementAct>{},
|
||||
Layout<Shape <_8, _16>,
|
||||
Stride<_1, _8>>{},
|
||||
Layout<Shape <_4, _1>>{}));
|
||||
|
||||
using SmemCopyAtomOut = Copy_Atom<UniversalCopy<uint32_t>, ElementOut>;
|
||||
|
||||
// This can be optimized to make accesses BCF, but we use a col-major layout here to show off composability
|
||||
using SmemLayoutOut = Layout<Shape<TileSizeM, TileSizeN>>;
|
||||
|
||||
//
|
||||
// Conv functor
|
||||
//
|
||||
template <class EngineFlt, class TensorActivation, class TensorOutput>
|
||||
void __device__
|
||||
operator()(cute::Tensor<EngineFlt, GmemLayoutFlt> mFlt, // ( K, (C,T,R,S))
|
||||
TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S))
|
||||
TensorOutput mOut, // ( K, (N,Z,P,Q))
|
||||
char* smem_buf) const {
|
||||
using namespace cute;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma<
|
||||
cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
|
||||
Shape<TileM,TileN,TileK>,
|
||||
ElementFlt,
|
||||
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
|
||||
ElementAct,
|
||||
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
|
||||
TiledMma,
|
||||
GmemTiledCopyFlt,
|
||||
SmemLayoutAtomFlt,
|
||||
SmemCopyAtomFlt,
|
||||
cute::identity,
|
||||
GmemTiledCopyAct,
|
||||
SmemLayoutAtomAct,
|
||||
SmemCopyAtomAct,
|
||||
cute::identity>;
|
||||
|
||||
TiledMma tiled_mma;
|
||||
Tensor accum = partition_fragment_C(tiled_mma, TilerOut{});
|
||||
clear(accum);
|
||||
|
||||
// Set up tensors
|
||||
// NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW
|
||||
Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k')
|
||||
Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1)
|
||||
Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n')
|
||||
|
||||
// Compute m_coord and n_coord with their post-tiled shapes
|
||||
auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk));
|
||||
auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk));
|
||||
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k')
|
||||
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1)
|
||||
Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N)
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(size<2>(gA));
|
||||
int k_tile_count = size<2>(gA);
|
||||
|
||||
CollectiveMainloop collective_mma;
|
||||
collective_mma(
|
||||
accum,
|
||||
gA,
|
||||
gB,
|
||||
accum,
|
||||
k_tile_iter, k_tile_count,
|
||||
Underscore{}, // no residue since we do not support predication
|
||||
threadIdx.x,
|
||||
smem_buf);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{});
|
||||
|
||||
auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma);
|
||||
auto smem_thr_copy_C = smem_tiled_copy_C.get_slice(threadIdx.x);
|
||||
auto tCrC = smem_thr_copy_C.retile_S(accum);
|
||||
auto tCsC = smem_thr_copy_C.partition_D(sC);
|
||||
copy(smem_tiled_copy_C, tCrC, tCsC);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
GmemTiledCopyOut gmem_tiled_copy_C;
|
||||
auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x);
|
||||
auto tDsC = gmem_thr_copy_C.partition_S(sC);
|
||||
auto tDgC = gmem_thr_copy_C.partition_D(gC);
|
||||
copy(gmem_tiled_copy_C, tDsC, tDgC);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("mAct = "); print(mAct); print('\n');
|
||||
print("mFlt = "); print(mFlt); print('\n');
|
||||
print("mOut = "); print(mOut); print('\n');
|
||||
print("gA = "); print(gA); print('\n');
|
||||
print("gB = "); print(gB); print('\n');
|
||||
print("gC = "); print(gC); print('\n');
|
||||
print("sA = "); print(sA.layout()); print('\n');
|
||||
print("sB = "); print(sB.layout()); print('\n');
|
||||
print("sC = "); print(sC.layout()); print('\n');
|
||||
print("tAgA = "); print(tAgA.layout()); print('\n');
|
||||
print("tBgB = "); print(tBgB.layout()); print('\n');
|
||||
print("tAsA = "); print(tAsA.layout()); print('\n');
|
||||
print("tBsB = "); print(tBsB.layout()); print('\n');
|
||||
print("tCsA = "); print(tCsA.layout()); print('\n');
|
||||
print("tCsB = "); print(tCsB.layout()); print('\n');
|
||||
print("tCrC = "); print(tCrC.layout()); print('\n');
|
||||
print("tCsC = "); print(tCsC.layout()); print('\n');
|
||||
print("tDsC = "); print(tDsC.layout()); print('\n');
|
||||
print("tDgC = "); print(tDgC.layout()); print('\n');
|
||||
print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n');
|
||||
print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n');
|
||||
print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n');
|
||||
print("k_tile_count = "); print(size<2>(gA)); print('\n');
|
||||
print("k_tile_iter = "); print(*k_tile_iter); print('\n');
|
||||
print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n');
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <class TensorFlt, class TensorAct, class TensorOut>
|
||||
inline int
|
||||
fprop_reference(
|
||||
TensorFlt mStencil, // Logical MK: ( K, (C,T,R,S))
|
||||
TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S))
|
||||
TensorOut mOutput, // Logical MN: ( K, (N,Z,P,Q))
|
||||
TensorOut mOutputRef) {
|
||||
int32_t N = size<1,0>(mOutputRef);
|
||||
int32_t Z = size<1,1>(mOutputRef);
|
||||
int32_t P = size<1,2>(mOutputRef);
|
||||
int32_t Q = size<1,3>(mOutputRef);
|
||||
int32_t T = size<1,3>(mStencil);
|
||||
int32_t R = size<1,2>(mStencil);
|
||||
int32_t S = size<1,1>(mStencil);
|
||||
int32_t C = size<1,0>(mStencil);
|
||||
|
||||
size_t K = static_cast<size_t>(size<0>(mOutputRef));
|
||||
size_t NZPQ = static_cast<size_t>(size<1>(mOutputRef));
|
||||
size_t CTRS = static_cast<size_t>(size<1>(mStencil));
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (size_t logical_m = 0; logical_m < K; ++logical_m) {
|
||||
for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) {
|
||||
auto accumulator = float(0);
|
||||
for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) {
|
||||
accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
|
||||
}
|
||||
mOutputRef(logical_m, logical_n) = accumulator;
|
||||
}
|
||||
}
|
||||
|
||||
return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01);
|
||||
}
|
||||
@ -0,0 +1,392 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propogation kernel
|
||||
capable of operating on both affine and gather/scatter tensors.
|
||||
|
||||
This example demonstartes a few super cool features of CUTLASS and CuTe. It shows off
|
||||
1. A dense conv 3D fprop kernel written as a single file ...
|
||||
2. ... that leverages off the shelf CUTLASS collectives to show how custom kernels can use collectives ...
|
||||
3. ... and uses the exact same templated kernel to also stamp out a gather/scatter 3D fprop conv ...
|
||||
4. ... while getting near peak performance of the Ampere class tensor core on Ampere and Ada GPUs ...
|
||||
5. ... by using static cute shapes and strides in case problem shapes are known at compile time.
|
||||
|
||||
Full documentation for this example can be found within the README.md file in this directory.
|
||||
|
||||
Example executions:
|
||||
./59_ampere_gather_scatter_conv
|
||||
./59_ampere_gather_scatter_conv --n=108
|
||||
./59_ampere_gather_scatter_conv --n=4096 --i=1
|
||||
./59_ampere_gather_scatter_conv --n=1080 --i=1000
|
||||
./59_ampere_gather_scatter_conv --n=131072 --i=1000 --no-check
|
||||
*/
|
||||
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/universal_vector.h>
|
||||
|
||||
#include "ampere_conv_kernel.h"
|
||||
#include "gather_tensor.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
|
||||
bool check_cuda_result(cudaError_t code, const char* file, int line) {
|
||||
if (code == cudaSuccess) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::cerr << "CUDA error at (" << file << "," << line << ")\n\t" << unsigned(code) << " -- " << cudaGetErrorString(code) << "\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(code) (check_cuda_result(code, __FILE__, __LINE__))
|
||||
|
||||
using namespace cute;
|
||||
using example::IndexedGather;
|
||||
using example::CustomStride;
|
||||
|
||||
template<class Operator, class FilterTensor, class ActivationTensor, class OutputTensor>
|
||||
__global__
|
||||
__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
|
||||
void kernel_entrypoint(FilterTensor mFlt, ActivationTensor mAct, OutputTensor mOut) {
|
||||
extern __shared__ char smem_buf[];
|
||||
Operator op;
|
||||
op(mFlt, mAct, mOut, smem_buf);
|
||||
}
|
||||
|
||||
int ampere_dense_conv_fprop(
|
||||
int num_images,
|
||||
float* activations,
|
||||
float* filter,
|
||||
float* output,
|
||||
float* output_ref,
|
||||
int num_iterations = 1,
|
||||
bool do_ref_check = true) {
|
||||
auto D = typename AmpereUnpredicatedFprop::D{};
|
||||
auto H = typename AmpereUnpredicatedFprop::H{};
|
||||
auto W = typename AmpereUnpredicatedFprop::W{};
|
||||
auto Z = typename AmpereUnpredicatedFprop::Z{};
|
||||
auto P = typename AmpereUnpredicatedFprop::P{};
|
||||
auto Q = typename AmpereUnpredicatedFprop::Q{};
|
||||
auto C = typename AmpereUnpredicatedFprop::C{};
|
||||
auto K = typename AmpereUnpredicatedFprop::K{};
|
||||
auto S = typename AmpereUnpredicatedFprop::S{};
|
||||
auto R = typename AmpereUnpredicatedFprop::R{};
|
||||
auto T = typename AmpereUnpredicatedFprop::T{};
|
||||
|
||||
int N = num_images; // dynamic
|
||||
if (num_images % int(typename AmpereUnpredicatedFprop::Tiler_N{}) != 0) {
|
||||
printf("ERROR: Input image count must be evenly divisible by CTA tiler N.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Tensor Activation: (n,d,h,w,c)::(?,6,4,4,64):(6144,1536,384,64,1)
|
||||
auto activation_layout = make_layout(
|
||||
make_shape (make_shape ( N, D, H, W), make_shape ( C, _1{},_1{},_1{})),
|
||||
make_stride(make_stride(D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, _0{},_0{},_0{})));
|
||||
|
||||
auto xformed_act_layout = make_layout(
|
||||
make_shape (make_shape(N, Z, P, Q), make_shape ( C, T, R, S)),
|
||||
make_stride(stride<0>(activation_layout), make_stride(_1{}, H*W*C, W*C, C)));
|
||||
|
||||
// Tensor Filter : (k,c,s,r,t)::(128,3,3,3,64):(1728,576,192,64,1)
|
||||
auto filter_layout = AmpereUnpredicatedFprop::GmemLayoutFlt{};
|
||||
|
||||
// Tensor Output : (n,z,p,q,k)::(?,4,2,2,128):(2048,1024,512,128,1)
|
||||
auto output_layout = make_ordered_layout(
|
||||
make_shape( K, make_shape( N, Z, P, Q)),
|
||||
make_tuple(_0{}, make_tuple(_4{},_3{},_2{},_1{})));
|
||||
|
||||
Tensor mActivation = make_tensor(make_gmem_ptr(activations), activation_layout);
|
||||
Tensor mXformedAct = make_tensor(make_gmem_ptr(activations), xformed_act_layout);
|
||||
Tensor mFilter = make_tensor(make_gmem_ptr(filter), filter_layout);
|
||||
Tensor mOutput = make_tensor(make_gmem_ptr(output), output_layout); // (K, (N,Z,P,Q))
|
||||
Tensor mOutputRef = make_tensor(make_gmem_ptr(output_ref), output_layout);
|
||||
|
||||
print("xformed act layout ((N,Z,P,Q), (C,T,R,S)) = "); print(xformed_act_layout); print("\n");
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
CHECK_CUDA(cudaEventCreate(&start));
|
||||
CHECK_CUDA(cudaEventCreate(&stop));
|
||||
|
||||
constexpr size_t smem_size = sizeof(typename AmpereUnpredicatedFprop::SharedStorage);
|
||||
Tensor gOutput_mn = zipped_divide(mOutput, typename AmpereUnpredicatedFprop::TilerOut{}); // ((BLK_M, BLK_N), (m', n'))
|
||||
dim3 lauch_grid {static_cast<uint32_t>(size<1,1>(gOutput_mn)), static_cast<uint32_t>(size<1,0>(gOutput_mn)), 1};
|
||||
|
||||
CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedAct), decltype(mOutput)>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size));
|
||||
|
||||
CHECK_CUDA(cudaEventRecord(start));
|
||||
for (int i = 0; i < num_iterations; ++i) {
|
||||
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedAct), decltype(mOutput)>
|
||||
<<<lauch_grid, AmpereUnpredicatedFprop::MaxThreadsPerBlock, smem_size>>>(
|
||||
mFilter, mXformedAct, mOutput);
|
||||
}
|
||||
CHECK_CUDA(cudaEventRecord(stop));
|
||||
CHECK_CUDA(cudaEventSynchronize(stop));
|
||||
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
milliseconds /= float(num_iterations);
|
||||
|
||||
double tflop_count = (2 * double(size<0>(xformed_act_layout)) * double(size(filter_layout))) / double(1e12);
|
||||
double tflops = tflop_count / (double(milliseconds) / double(1e3));
|
||||
|
||||
printf("Conv TFLOP count = %f\n", tflop_count);
|
||||
printf("Conv dense perf: %fms | TFLOP/s = %f\n", milliseconds, tflops);
|
||||
|
||||
if (do_ref_check) {
|
||||
printf("Running host reference check ...\n");
|
||||
return fprop_reference(mFilter, mXformedAct, mOutput, mOutputRef);
|
||||
}
|
||||
else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
int ampere_gather_scatter_conv_fprop(
|
||||
int num_images,
|
||||
float* activations,
|
||||
uint32_t *gather_idx_buf,
|
||||
float* filter,
|
||||
float* output,
|
||||
uint32_t *scatter_idx_buf,
|
||||
int num_iterations = 1) {
|
||||
auto D = typename AmpereUnpredicatedFprop::D{};
|
||||
auto H = typename AmpereUnpredicatedFprop::H{};
|
||||
auto W = typename AmpereUnpredicatedFprop::W{};
|
||||
auto Z = typename AmpereUnpredicatedFprop::Z{};
|
||||
auto P = typename AmpereUnpredicatedFprop::P{};
|
||||
auto Q = typename AmpereUnpredicatedFprop::Q{};
|
||||
auto C = typename AmpereUnpredicatedFprop::C{};
|
||||
auto K = typename AmpereUnpredicatedFprop::K{};
|
||||
auto S = typename AmpereUnpredicatedFprop::S{};
|
||||
auto R = typename AmpereUnpredicatedFprop::R{};
|
||||
auto T = typename AmpereUnpredicatedFprop::T{};
|
||||
|
||||
int N = num_images; // dynamic
|
||||
if (N % int(typename AmpereUnpredicatedFprop::Tiler_N{}) != 0) {
|
||||
printf("ERROR: Input image count must be evenly divisible by CTA tiler N. Got num_images = %d\n", N);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Tensor Filter : (k,c,s,r,t)::(128,3,3,3,64):(1728,576,192,64,1)
|
||||
auto filter_layout = AmpereUnpredicatedFprop::GmemLayoutFlt{};
|
||||
|
||||
// Tensor Output : (n,z,p,q,k)::(?,4,2,2,128):(2048,1024,512,128,1)
|
||||
auto output_layout = make_ordered_layout(
|
||||
make_shape( K, make_shape( N, Z, P, Q)),
|
||||
make_tuple(_0{}, make_tuple(_4{},_3{},_2{},_1{})));
|
||||
|
||||
// Input gather layout
|
||||
// inner_layout(make_coord((nzpq), (csrt))) => (idx_buffer_idx, dense_c_idx)
|
||||
auto EG = E<0>{}; // Gather basis (1,0) (idx_buffer_idx)
|
||||
auto EC = E<1>{}; // Contiguous basis (0,1) (dense_offset)
|
||||
auto xformed_act_logical_inner = make_layout(
|
||||
make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)),
|
||||
make_stride(make_stride(D*H*W*EG, H*W*EG, W*EG, EG), make_stride(EC, H*W*EG, W*EG, EG)));
|
||||
|
||||
// outer_layout(make_coord(idx_buffer_idx, dense_c_idx)) => idx
|
||||
// IndexedGather obtains idx by applying (gmem_base_ptr + gather_idx_buf[idx_buffer_idx] + dense_offset)
|
||||
auto xformed_act_gather_outer = make_layout(
|
||||
make_shape(_1{},_1{}),
|
||||
make_stride(CustomStride{IndexedGather{gather_idx_buf}, C}, _1{}));
|
||||
|
||||
// Compose the inner and outer layouts
|
||||
// gather_composed(make_coord((nzpq), (csrt))) => idx
|
||||
auto xformed_act_composed_layout = composition(
|
||||
xformed_act_gather_outer,
|
||||
make_arithmetic_tuple(_0{}, _0{}),
|
||||
xformed_act_logical_inner);
|
||||
|
||||
// Output scatter layout
|
||||
auto out_basis_stride = make_stride(
|
||||
E<1>{},
|
||||
make_stride(Z*P*Q*E<0>{}, P*Q*E<0>{}, Q*E<0>{}, _1{}*E<0>{})); // -> (crd0, crd1)
|
||||
auto out_basis_layout = make_layout(shape(output_layout), out_basis_stride);
|
||||
auto out_scatter_layout = make_layout(
|
||||
make_shape(_1{},_1{}),
|
||||
make_stride(CustomStride{IndexedGather{scatter_idx_buf}, K}, _1{}));
|
||||
auto out_composed_layout = composition(
|
||||
out_scatter_layout,
|
||||
make_arithmetic_tuple(_0{},_0{}),
|
||||
out_basis_layout);
|
||||
|
||||
Tensor mXformedActGather = make_tensor(make_gmem_ptr(activations), xformed_act_composed_layout);
|
||||
Tensor mFilter = make_tensor(make_gmem_ptr(filter), filter_layout);
|
||||
Tensor mOutputScatter = make_tensor(make_gmem_ptr(output), out_composed_layout); // (K, (N,Z,P,Q))
|
||||
|
||||
Tensor gOutput_mn = zipped_divide(mOutputScatter, typename AmpereUnpredicatedFprop::TilerOut{}); // ((BLK_M, BLK_N), (m', n'))
|
||||
dim3 lauch_grid {static_cast<uint32_t>(size<1,1>(gOutput_mn)), static_cast<uint32_t>(size<1,0>(gOutput_mn)), 1};
|
||||
constexpr size_t smem_size = sizeof(typename AmpereUnpredicatedFprop::SharedStorage);
|
||||
|
||||
print("xforemed gather layout ((N,Z,P,Q), (C,T,R,S)) = "); print(xformed_act_composed_layout); print("\n");
|
||||
print("Output scatter layout ( K, (N,Z,P,Q)) = "); print(out_composed_layout); print("\n");
|
||||
print("Filter layout ( K, (C,T,R,S)) = "); print(filter_layout); print("\n");
|
||||
|
||||
CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedActGather), decltype(mOutputScatter)>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size));
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
CHECK_CUDA(cudaEventCreate(&start));
|
||||
CHECK_CUDA(cudaEventCreate(&stop));
|
||||
CHECK_CUDA(cudaEventRecord(start));
|
||||
for (int i = 0; i < num_iterations; ++i) {
|
||||
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedActGather), decltype(mOutputScatter)>
|
||||
<<<lauch_grid, AmpereUnpredicatedFprop::MaxThreadsPerBlock, smem_size>>>(
|
||||
mFilter, mXformedActGather, mOutputScatter);
|
||||
}
|
||||
CHECK_CUDA(cudaEventRecord(stop));
|
||||
CHECK_CUDA(cudaEventSynchronize(stop));
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
milliseconds /= float(num_iterations);
|
||||
|
||||
double tflop_count = (2 * double(size<0>(xformed_act_logical_inner)) * double(size(filter_layout))) / double(1e12);
|
||||
double tflops = tflop_count / (double(milliseconds) / double(1e3));
|
||||
printf("Conv TFLOP count = %f\n", tflop_count);
|
||||
printf("Conv gather/scatter perf: %fms | TFLOP/s = %f\n", milliseconds, tflops);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int
|
||||
main(int argc, char const** argv) {
|
||||
cutlass::CommandLine cmd(argc, argv);
|
||||
std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n";
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
std::cout
|
||||
<< "Options:\n"
|
||||
"\t--n=<int> Sets the number of images for the input activation tensor (dataset size). Default = 131072.\n"
|
||||
"\t--i=<int> Sets the benchmarking repetitions. Default = 128.\n"
|
||||
"\t--nocheck If specified, skips the reference check for dense kernel.\n"
|
||||
"\t--help Displays this help message and exits.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
cudaDeviceProp props;
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
if (props.major < 8) {
|
||||
std::cerr << "This example requires an Ampere GPU or newer.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
int num_images = 4320;
|
||||
cmd.get_cmd_line_argument("n", num_images, 4320);
|
||||
int num_iterations = 128;
|
||||
cmd.get_cmd_line_argument("i", num_iterations, 128);
|
||||
bool do_host_ref_check = not cmd.check_cmd_line_flag("no-check");
|
||||
|
||||
auto D = typename AmpereUnpredicatedFprop::D{};
|
||||
auto H = typename AmpereUnpredicatedFprop::H{};
|
||||
auto W = typename AmpereUnpredicatedFprop::W{};
|
||||
auto Z = typename AmpereUnpredicatedFprop::Z{};
|
||||
auto P = typename AmpereUnpredicatedFprop::P{};
|
||||
auto Q = typename AmpereUnpredicatedFprop::Q{};
|
||||
auto C = typename AmpereUnpredicatedFprop::C{};
|
||||
auto K = typename AmpereUnpredicatedFprop::K{};
|
||||
|
||||
auto activation_layout = make_layout(
|
||||
make_shape (make_shape (num_images, D, H, W), make_shape ( C, _1{},_1{},_1{})),
|
||||
make_stride(make_stride( D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, _0{},_0{},_0{})));
|
||||
|
||||
auto filter_layout = typename AmpereUnpredicatedFprop::GmemLayoutFlt{};
|
||||
|
||||
auto output_layout = make_ordered_layout(
|
||||
make_shape( K, make_shape(num_images, Z, P, Q)),
|
||||
make_step (_0{}, make_step ( _4{},_3{},_2{},_1{})));
|
||||
|
||||
print("Filter layout ( K, (C,T,R,S)) = "); print(filter_layout); print("\n");
|
||||
print("Activation layout ((N,D,H,W), (C,1,1,1)) = "); print(activation_layout); print("\n");
|
||||
print("Output layout ( K, (N,Z,P,Q)) = "); print(output_layout); print("\n");
|
||||
|
||||
// allocate tensors
|
||||
std::cout << "Allocating tensors ... ";
|
||||
thrust::universal_vector<float> activation_data(size_t(cute::size(activation_layout)), float(0));
|
||||
thrust::universal_vector<float> filter_data(size_t(cute::size(filter_layout)), float(0));
|
||||
thrust::universal_vector<float> output_data(size_t(cute::size(output_layout)), float(0));
|
||||
thrust::universal_vector<float> output_data_ref(size_t(cute::size(output_layout)), float(0));
|
||||
std::cout << "done.\n";
|
||||
|
||||
// init tensors
|
||||
std::cout << "Initializing data ... " << std::flush;
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<float> uniform_dist(-1.0, 1.0);
|
||||
for (std::size_t i = 0; i < size_t(cute::size(activation_layout)); ++i) {
|
||||
activation_data[i] = uniform_dist(gen);
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < size_t(cute::size(filter_layout)); ++i) {
|
||||
filter_data[i] = uniform_dist(gen);
|
||||
}
|
||||
std::cout << "done.\n";
|
||||
|
||||
// set up index buffers for gather/scatter, fill with indireciton indices in reversed order
|
||||
std::cout << "Initializing gather/scatter index buffers ... ";
|
||||
thrust::universal_vector<uint32_t> gather_idx_buf(size_t(size<0>(activation_layout)));
|
||||
thrust::universal_vector<uint32_t> scatter_idx_buf(size_t(size<1>(output_layout)));
|
||||
thrust::sequence(gather_idx_buf.rbegin(), gather_idx_buf.rend());
|
||||
thrust::sequence(scatter_idx_buf.rbegin(), scatter_idx_buf.rend());
|
||||
std::cout << "done.\n";
|
||||
|
||||
// launch dense
|
||||
std::cout << "\nRunning dense fprop kernel\n";
|
||||
int passed = ampere_dense_conv_fprop(
|
||||
num_images,
|
||||
activation_data.data().get(),
|
||||
filter_data.data().get(),
|
||||
output_data.data().get(),
|
||||
output_data_ref.data().get(),
|
||||
num_iterations,
|
||||
do_host_ref_check);
|
||||
|
||||
// launch gather/scatter
|
||||
std::cout << "\nRunning gather/scatter fprop kernel\n";
|
||||
ampere_gather_scatter_conv_fprop(
|
||||
num_images,
|
||||
activation_data.data().get(),
|
||||
gather_idx_buf.data().get(),
|
||||
filter_data.data().get(),
|
||||
output_data.data().get(),
|
||||
scatter_idx_buf.data().get(),
|
||||
num_iterations);
|
||||
|
||||
return passed;
|
||||
}
|
||||
@ -0,0 +1,534 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper GEMM + Top-K + Softmax fusion
|
||||
|
||||
This example illustrates how to use the LinCombTopKSoftmaxCol EVT node to fuse
|
||||
Top-K and Softmax into the GEMM epilogue, with certain assumptions made.
|
||||
|
||||
Those assumptions are as:
|
||||
1. Fusion is over the N dimension.
|
||||
2. Top-K is either 2 or 4 elements, and the value is static (meaning two kernels have to be
|
||||
compiled to support both.)
|
||||
3. The GEMM tile shape along N is greater than or equal to problem size
|
||||
along N.
|
||||
|
||||
|
||||
The example runs the fused GEMM kernel, along with a standard unfused host reference, and
|
||||
manually performs Top-K and softmax, and compares the error between tensors.
|
||||
|
||||
Note that some numerical error (smaller than 1e-5) is to be expected, but this is true
|
||||
in most efficient reduction kernels, because floating point addition is not necessarily
|
||||
associative.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
static constexpr int TopK = 2;
|
||||
static constexpr bool EnableTopKSoftmax = TopK > 1;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentC = 1;
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutD = cutlass::layout::RowMajor; // Layout type for output
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of output in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
|
||||
// Top-K + Softmax fusion operation
|
||||
using FusionOperation = std::conditional_t<EnableTopKSoftmax,
|
||||
typename cutlass::epilogue::fusion::LinCombTopKSoftmaxCol<TopK, ElementD, ElementCompute>,
|
||||
typename cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute>
|
||||
>;
|
||||
|
||||
// The fusion op only allows for epilogue tiles matching the mainloop tile.
|
||||
using EpilogueTileType = decltype(cute::take<0,2>(TileShape{}));
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
int iterations = 1000;
|
||||
int m = 16, n = 8, k = 64, l = 1;
|
||||
double eps = 1e-5;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("eps", eps);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "61_hopper_gemm_with_topk_and_softmax\n\n"
|
||||
<< " Hopper FP8 GEMM with Top-K and softmax fusion.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --eps=<float> Threshold of numerical verification. Default: 1e-5.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "61_hopper_gemm_with_topk_and_softmax" << " --m=16 --n=8 --k=1024 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
float alpha() const {
|
||||
return 1.f / static_cast<float>(k);
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, /* max = */ 1, /* min = */ -1, /* bits = */ 2);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2023);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_D.sync_device();
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
|
||||
{
|
||||
{options.alpha(), 0.f}, // alpha, beta
|
||||
nullptr, stride_D,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
unused_t,
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha();
|
||||
epilogue_params.beta = 0.f;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
if constexpr (EnableTopKSoftmax) {
|
||||
// top-K + softmax
|
||||
for (int i = 0; i < options.m; ++i) {
|
||||
|
||||
// Find Top-K
|
||||
cutlass::Array<ElementAccumulator, TopK> top_k;
|
||||
top_k.fill(-cutlass::platform::numeric_limits<ElementCompute>::infinity());
|
||||
for (int j = 0; j < options.n; ++j) {
|
||||
auto val = static_cast<ElementAccumulator>(tensor_ref_D.host_view().ref().at({i, j}));
|
||||
for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) {
|
||||
if (val > top_k[top_k_idx]) {
|
||||
// Shift down
|
||||
for (int l = TopK - 1; l > top_k_idx; --l) {
|
||||
top_k[l] = top_k[l - 1];
|
||||
}
|
||||
top_k[top_k_idx] = val;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This formulation of top-K + softmax only works when it is
|
||||
// guaranteed that none of the top-K elements are repeated!
|
||||
// If this is the case, the device kernel can also make mistakes, because
|
||||
// A. Once the top-K values are reduced, and the operation is being applied,
|
||||
// there is no way to tell repeated elements apart, so none are masked.
|
||||
// B. The softmax sum of exps will be incorrect (because the repeated elements
|
||||
// are not repeated in it.)
|
||||
|
||||
ElementAccumulator max = top_k[0];
|
||||
ElementAccumulator sum = ElementAccumulator(0.f);
|
||||
for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) {
|
||||
sum = sum + cutlass::fast_exp(top_k[top_k_idx] - max);
|
||||
}
|
||||
|
||||
for (int j=0; j < options.n; ++j) {
|
||||
auto val = tensor_ref_D.host_view().ref().at({i, j});
|
||||
if (val < top_k[TopK - 1]) {
|
||||
tensor_ref_D.host_view().ref().at({i, j}) = static_cast<ElementD>(0.f);
|
||||
} else {
|
||||
// Softmax
|
||||
auto softmax_val = cutlass::fast_exp(val - max) / sum;
|
||||
tensor_ref_D.host_view().ref().at({i, j}) = static_cast<ElementD>(softmax_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
|
||||
double err = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_D.host_view(),
|
||||
tensor_ref_D.host_view());
|
||||
bool passed = err < options.eps;
|
||||
|
||||
if (options.m <= 32 && options.n <= 32) {
|
||||
std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n";
|
||||
std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options) {
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
32
examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt
Normal file
32
examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
61_hopper_gemm_with_topk_and_softmax
|
||||
61_hopper_gemm_with_topk_and_softmax.cu
|
||||
)
|
||||
596
examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu
Normal file
596
examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu
Normal file
@ -0,0 +1,596 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper Sparse GEMM example.
|
||||
|
||||
This example demonstrates how to construct and run a structured sparse GEMM kernel
|
||||
on NVIDIA Hopper architecture.
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size for sparse kernel
|
||||
using TileShapeRef = Shape<_128,_128, _64>; // Threadblock-level tile size for reference (dense) kernel
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; // Kernel schedule policy
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue schedule policy
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
|
||||
// Sparse kernel setup
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||
ElementA, LayoutTagA, AlignmentA,
|
||||
ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference (dense) kernel setup
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShapeRef, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutTagA, AlignmentA,
|
||||
ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShapeRef, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
|
||||
// Layouts
|
||||
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
// Layouts for reference (non-sparse) tensors
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
|
||||
using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
// Offline compressor kernel
|
||||
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig,
|
||||
cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ProblemShape problem_shape;
|
||||
|
||||
StrideA stride_A;
|
||||
StrideA stride_A_compressed;
|
||||
StrideE stride_E;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
|
||||
LayoutA layout_A;
|
||||
LayoutE layout_E;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A_compressed;
|
||||
cutlass::DeviceAllocation<typename Gemm::CollectiveMainloop::ElementE> block_E;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D_ref;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k, l;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(5120), n(4096), k(16384), l(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "62_hopper_sparse_gemm\n\n"
|
||||
<< " Hopper Sparse GEMM example.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent of the GEMM (batch size)\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "62_hopper_sparse_gemm" << " --m=4096 --n=5120 --k=8192 --l=1 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
bool sparsify_and_compress()
|
||||
{
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
CompressorUtility compressor_utility(problem_shape, stride_A);
|
||||
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
block_A_compressed.reset(M * KC * L);
|
||||
block_E.reset(ME * KE * L);
|
||||
|
||||
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L));
|
||||
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L));
|
||||
|
||||
// Random sparsification is performed on host
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
|
||||
compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast<int>(seed + 2024));
|
||||
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
typename Compressor::Arguments arguments {
|
||||
problem_shape,
|
||||
{ block_A.get(),
|
||||
stride_A,
|
||||
block_A_compressed.get(),
|
||||
block_E.get() },
|
||||
{hw_info} };
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
bool initialize(Options const& options) {
|
||||
|
||||
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
// Allocate memory for tensors
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(N * K * L);
|
||||
block_C.reset(M * N * L);
|
||||
block_D.reset(M * N * L);
|
||||
block_D_ref.reset(M * N * L);
|
||||
|
||||
// Fill input tensors with data
|
||||
initialize_block(block_A, seed + 2021);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2023);
|
||||
|
||||
// Replace 0 in A with 1 to avoid metadata changes
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
|
||||
for (size_t i = 0; i < block_A.size(); ++i) if (block_A_host[i] == ElementA(0)) block_A_host[i] = ElementA(1.0);
|
||||
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
|
||||
|
||||
if (!sparsify_and_compress()) {
|
||||
return false;
|
||||
};
|
||||
|
||||
// Build the compressed/metadata layouts
|
||||
layout_A = SparseConfig::fill_layoutA(problem_shape);
|
||||
layout_E = SparseConfig::fill_layoutE(problem_shape);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments make_args(Options const& options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_shape,
|
||||
{ block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E },
|
||||
{ { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) },
|
||||
block_C.get(), stride_C, block_D.get(), stride_D }
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
typename GemmRef::Arguments make_args_ref(Options const& options)
|
||||
{
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_shape,
|
||||
{ block_A.get(), stride_A, block_B.get(), stride_B },
|
||||
{ { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) },
|
||||
block_C.get(), stride_C, block_D_ref.get(), stride_D }
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template<class Engine, class Layout>
|
||||
void print_device_tensor(cute::Tensor<Engine, Layout> const& t)
|
||||
{
|
||||
// Assumes size = cosize, i.e. compact tensor
|
||||
std::vector<typename Engine::value_type> data_host(t.size());
|
||||
cutlass::device_memory::copy_to_host(data_host.data(), t.data(), t.size());
|
||||
auto t_host = cute::make_tensor(data_host.data(), t.layout());
|
||||
cute::print_tensor(t_host);
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size());
|
||||
|
||||
#if 0
|
||||
if (!passed) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
CompressorUtility compressor_utility(problem_shape, stride_A);
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
cute::print("A (original): "); print_device_tensor(make_tensor(block_A.get(), make_shape(M, K, L), stride_A));
|
||||
cute::print("A (compressed): "); print_device_tensor(make_tensor(block_A_compressed.get(), make_shape(M, KC, L), stride_A_compressed));
|
||||
cute::print("E (physical): "); print_device_tensor(make_tensor(block_E.get(), make_shape(ME, KE, L), stride_E));
|
||||
cute::print("E (logical): "); print_device_tensor(make_tensor(block_E.get(), upcast<CollectiveMainloop::ElementEMmaSparsity>(layout_E)));
|
||||
cute::print("B: "); print_device_tensor(make_tensor(block_B.get(), make_shape(N, K, L), stride_B));
|
||||
cute::print("C: "); print_device_tensor(make_tensor(block_C.get(), make_shape(M, N, L), stride_C));
|
||||
cute::print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M, N, L), stride_D));
|
||||
cute::print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M, N, L), stride_D));
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
template<typename Gemm>
|
||||
struct Runner
|
||||
{
|
||||
using Arguments = typename Gemm::Arguments;
|
||||
|
||||
Runner(Arguments args): arguments(args) {
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
workspace.reset(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
}
|
||||
|
||||
void run() {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
|
||||
void benchmark(Options const& options) {
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
run();
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double gflops = options.gflops(avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Avg runtime: " << avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << gflops << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
Gemm gemm;
|
||||
Arguments arguments;
|
||||
cutlass::device_memory::allocation<uint8_t> workspace;
|
||||
};
|
||||
|
||||
/// Execute the example (verification and timing)
|
||||
void run(Options &options) {
|
||||
bool init = initialize(options);
|
||||
if (!init) {
|
||||
std::cout << "Initialization failure" << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
Runner<Gemm> gemm(make_args(options));
|
||||
Runner<GemmRef> gemm_ref(make_args_ref(options));
|
||||
|
||||
gemm.run();
|
||||
gemm_ref.run();
|
||||
|
||||
bool passed = verify(options);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!passed) {
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
std::cout << "Sparse GEMM:" << std::endl;
|
||||
gemm.benchmark(options);
|
||||
|
||||
std::cout << "Dense GEMM:" << std::endl;
|
||||
gemm_ref.benchmark(options);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.2 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) {
|
||||
std::cerr << "This example requires CUDA 12.2 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
run(options);
|
||||
#endif
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
36
examples/62_hopper_sparse_gemm/CMakeLists.txt
Normal file
36
examples/62_hopper_sparse_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
|
||||
# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Sparse kernel in this example triggers an ICE in gcc 7.5
|
||||
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
|
||||
cutlass_example_add_executable(
|
||||
62_hopper_sparse_gemm
|
||||
62_hopper_sparse_gemm.cu
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,500 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper FP8 GEMM + L2 Weight Prefetch
|
||||
|
||||
This example implements a non-persistent warp-specialized GEMM kernel for the Hopper
|
||||
architecture with programmatic dependent launch (PDL) enabling prefetching weights into
|
||||
L2 cache.
|
||||
|
||||
For more information about dependent launch refer to the CUDA programming guide:
|
||||
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization
|
||||
|
||||
In some cases, PDL can result in a window where a previous kernel is not actively utilizing
|
||||
DRAM, and the next kernel sits idle until the previous finishes. During this window, the next
|
||||
kernel can begin loading a non-dependent operand (i.e. weights in a linear projection are
|
||||
typically static) and cache it in L2.
|
||||
|
||||
The kernel and collective mainloop assume operand `A` corresponds to weights and operand `B`
|
||||
corresponds to activations (so we can have very small batch/token count).
|
||||
After initialization, the prefetch warp starts loading K tiles of `A` into an unused portion
|
||||
of shared memory, and loads up to half of all K tiles that the same CTA would eventually load.
|
||||
The exact number of K tiles loaded is determined by `args.mainloop.prefetch_ratio` \in
|
||||
[0.0, 1.0]. Smaller values result in less prefetching, and larger values result in more.
|
||||
Negative values result in a "best-effort" prefetch, meaning prefetcher will stop issuing weight
|
||||
loads as soon as the activation DMA warp starts loading (as soon as it is signaled that the
|
||||
previous kernel has flushed its memory.)
|
||||
|
||||
The DMA warp responsible for loading `A` will also begin loading K tiles until it fills up
|
||||
the available shared memory.
|
||||
The DMA warp responsible for loading `B` will wait until activations are flushed to global
|
||||
memory by the preceding kernel.
|
||||
|
||||
Another mainloop parameter, `args.mainloop.overlap_ratio` \in [0.0, 1.0] determines how early
|
||||
the next kernel (the one doing the prefetch) is launched. Smaller values result in greater
|
||||
overlap, and larger values result in smaller overlap. Negative values disable PDL completely,
|
||||
meaning there will be no overlap. This will make prefetch ineffective.
|
||||
|
||||
These two runtime parameters should be tuned per problem size and GEMM config combination, and
|
||||
if feasible, per-operation in an entire layer or model.
|
||||
|
||||
NOTE: you must build this target with the following flag to enable Grid Dependency Control
|
||||
instructions (GDC) in CUTLASS:
|
||||
- CUTLASS_ENABLE_GDC_FOR_SM90
|
||||
|
||||
To lock persistence mode, power (350W), clocks (1005MHz) for evaluation (assumes device 0 and H100)
|
||||
|
||||
$ sudo nvidia-smi -pm 1 -i 0
|
||||
|
||||
$ sudo nvidia-smi -i 0 -pl 350
|
||||
|
||||
$ sudo nvidia-smi -i 0 -lgc 1005
|
||||
|
||||
Example:
|
||||
|
||||
$ mkdir build && cd build
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_GDC_FOR_SM90=1
|
||||
|
||||
$ cd examples/63_hopper_gemm_with_weight_prefetch
|
||||
|
||||
$ make
|
||||
|
||||
$ ./63_hopper_gemm_with_weight_prefetch --p=0.5 --o=0.5
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
|
||||
#include "collective/dispatch_policy_extra.hpp"
|
||||
#include "collective/builder.hpp"
|
||||
#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "gemm_with_weight_prefetch_commandline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size
|
||||
// Cluster_N > 1 is not supported yet.
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
double eff_bw;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
double eff_bw = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), eff_bw(eff_bw), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), seed + 2024);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = scalar_alpha.device_data();
|
||||
fusion_args.beta_ptr = scalar_beta.device_data();
|
||||
|
||||
arguments.mainloop.overlap_ratio = options.overlap_ratio;
|
||||
arguments.mainloop.prefetch_ratio = options.prefetch_ratio;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0));
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0));
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0);
|
||||
result.gflops = options.gflops(avg_runtime_s);
|
||||
result.eff_bw = options.effective_bandwidth(avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD));
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
std::cout << " Effective bandwidth: " << result.eff_bw << " GB/s" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
36
examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt
Normal file
36
examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include_directories(
|
||||
.
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
63_hopper_gemm_with_weight_prefetch
|
||||
63_hopper_gemm_with_weight_prefetch.cu
|
||||
)
|
||||
82
examples/63_hopper_gemm_with_weight_prefetch/README.md
Normal file
82
examples/63_hopper_gemm_with_weight_prefetch/README.md
Normal file
@ -0,0 +1,82 @@
|
||||
# GEMM with L2 weight prefetch
|
||||
|
||||
A non-persistent warp specialized GEMM directed at low latency inference.
|
||||
|
||||
The kernel can optionally prefetch a portion of weights (operand `A`) into L2 cache while the
|
||||
rest of the warps are waiting on the previous kernel to finish writing and flush its memory.
|
||||
An example of this is normalization or reduction kernels that are immediately followed by a GEMM.
|
||||
|
||||
It exposes two runtime parameters:
|
||||
1. `overlap_ratio`: how early `griddepcontrol.launch_dependent_grids` is issued.
|
||||
Default is `0.5`, meaning after approximately half of K tiles are loaded by DMA warps.
|
||||
2. `prefetch_ratio`: what percentage of K tiles to prefetch.
|
||||
Default is `-1.0`, meaning prefetching will stop as soon as other DMA warps are past
|
||||
`griddepcontrol`.
|
||||
|
||||
It is highly recommended to auto-tune these parameters per GEMM and according to some end to end
|
||||
runtime (either an entire transformer layer or multiple, but probably not the entire model.)
|
||||
|
||||
TMA loads use non-default cache hints: `A` (weights) are loaded with `EvictFirst`, and `B` (activation)
|
||||
is loaded with `EvictLast`.
|
||||
|
||||
## Getting started
|
||||
To use this kernel in your own target, add this directory to your includes, and include the
|
||||
following headers from this example:
|
||||
|
||||
```cxx
|
||||
#include "collective/dispatch_policy_extra.hpp"
|
||||
#include "collective/builder.hpp"
|
||||
#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp"
|
||||
```
|
||||
|
||||
And then use either one of the new kernel schedules:
|
||||
|
||||
```cxx
|
||||
// Without separate warps for A and B
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetch;
|
||||
|
||||
// With separate warps for A and B
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
|
||||
```
|
||||
|
||||
The kernel with separate warps for A and B (
|
||||
`KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA`)
|
||||
is expected to be more performant than the other, especially since it allows the kernel to load
|
||||
weights into shmem ahead of the `griddepcontrol`.
|
||||
|
||||
As for other GEMM parameters, Thread Block Cluster larger than 1 CTA are not yet supported, and
|
||||
obviously the kernel layer implementation is warp specialized and uses the TMA, and other kernel
|
||||
layers or collectives require reimplementation.
|
||||
|
||||
## Example
|
||||
|
||||
Using the example is mostly straightforward.
|
||||
Just build, and run with your choice of `MNK`:
|
||||
|
||||
```bash
|
||||
./63_hopper_gemm_with_weight_prefetch --m=8192 --n=1 --k=8192
|
||||
```
|
||||
|
||||
You can also disable the overlap or try different overlap and prefetch ratios and see the
|
||||
difference:
|
||||
|
||||
```bash
|
||||
echo "Without overlap and prefetch"
|
||||
./63_hopper_gemm_with_weight_prefetch --o=-1.0 --p=-1.0
|
||||
|
||||
echo "Overlap ratio of 0.5, best effort prefetch"
|
||||
./63_hopper_gemm_with_weight_prefetch --o=0.5 --p=-1.0
|
||||
|
||||
echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
|
||||
./63_hopper_gemm_with_weight_prefetch --o=0.8 --p=0.7
|
||||
```
|
||||
|
||||
However, note that the example still runs a single GEMM, and most of the performance improvement
|
||||
is expected in end to end applications.
|
||||
|
||||
|
||||
## Limitations
|
||||
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
|
||||
When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
|
||||
memory barrier before issuing every single TMA load, and in many cases this will slow down
|
||||
prefetching to the point of being almost ineffective.
|
||||
@ -0,0 +1,242 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "dispatch_policy_extra.hpp"
|
||||
#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
|
||||
#include "../pipeline/prefetch_pipeline_sm90.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int stages>
|
||||
constexpr int
|
||||
compute_stage_count_or_override_prefetch(StageCount<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes>
|
||||
constexpr int
|
||||
compute_stage_count_or_override_prefetch(StageCountAutoCarveout<carveout_bytes> stage_count) {
|
||||
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
|
||||
constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>);
|
||||
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
|
||||
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
|
||||
constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size
|
||||
constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}));
|
||||
constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast<int>(mainloop_pipeline_bytes);
|
||||
|
||||
return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class KernelScheduleType
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
|
||||
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
|
||||
|
||||
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class KernelScheduleType
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
|
||||
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
|
||||
|
||||
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,61 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
// Standard non-persistent kernel with a single producer warp, and one prefetch warp.
|
||||
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
|
||||
// while the producer warp is waiting on griddepcontrol.
|
||||
// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and
|
||||
// according to prefetch ratio.
|
||||
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { };
|
||||
|
||||
// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp.
|
||||
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
|
||||
// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not
|
||||
// wait on griddepcontrol and loads immediately.
|
||||
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { };
|
||||
|
||||
template<
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch {
|
||||
constexpr static int Stages = Stages_;
|
||||
using ClusterShape = ClusterShape_;
|
||||
using ArchTag = arch::Sm90;
|
||||
using Schedule = KernelSchedule;
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@ -0,0 +1,872 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cutlass/arch/grid_dependency_control.h"
|
||||
|
||||
#include "dispatch_policy_extra.hpp"
|
||||
|
||||
#include "../pipeline/prefetch_pipeline_sm90.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
constexpr int PrefetchStages = 4;
|
||||
constexpr int PrefetchInitialStages = 1;
|
||||
// This determines how much shmem we set aside for prefetch.
|
||||
// We don't reuse anything loaded by prefetcher, so we can keep
|
||||
// loading into the same place -- there will be a conflict when
|
||||
// writing, but it doesn't affect performance as much as the doors
|
||||
// that this opens.
|
||||
constexpr int PrefetchStagesActual = 1;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1");
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
|
||||
using PrefetcherPipeline = cutlass::PrefetchPipeline<detail::PrefetchStages>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages);
|
||||
static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages);
|
||||
|
||||
using PrefetchSmemLayoutA = decltype(make_layout(make_shape(
|
||||
cute::Int<size<0>(SmemLayoutA{})>{},
|
||||
cute::Int<size<1>(SmemLayoutA{})>{},
|
||||
cute::Int<detail::PrefetchStagesActual>{})));
|
||||
|
||||
static constexpr auto prefetch_smem_size = cute::cosize_v<PrefetchSmemLayoutA>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
||||
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
|
||||
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
||||
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
||||
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
|
||||
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
|
||||
// Defined outside the class where it's used, to work around MSVC issues
|
||||
using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<detail::PrefetchStages>;
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128, _0> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, prefetch_smem_size> smem_prefetch;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
PrefetcherPipelineStorage prefetcher_pipeline;
|
||||
};
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
float overlap_ratio = 0.5;
|
||||
float prefetch_ratio = -1.0;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
float overlap_ratio = 0.5;
|
||||
float prefetch_ratio = -1.0;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
||||
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
tensor_a,
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
tensor_b,
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
||||
|
||||
return {
|
||||
tma_load_a,
|
||||
tma_load_b,
|
||||
transaction_bytes,
|
||||
transaction_bytes_mk,
|
||||
transaction_bytes_nk,
|
||||
args.overlap_ratio,
|
||||
args.prefetch_ratio
|
||||
};
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (args.overlap_ratio > 1.0) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (args.prefetch_ratio > 1.0) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytesMK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
|
||||
static constexpr uint32_t TmaTransactionBytesNK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// The rest of the tensors can be specified as needed by this collective.
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl);
|
||||
}
|
||||
|
||||
template <
|
||||
class TensorA, class TensorB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorA const& gA_mkl,
|
||||
TensorB const& gB_nkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
|
||||
float overlap_ratio = mainloop_params.overlap_ratio;
|
||||
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_a
|
||||
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
// Applies the mapping from cta_tma_b
|
||||
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// We have to wait on dependent grids because of B.
|
||||
cutlass::arch::wait_on_dependent_grids();
|
||||
|
||||
// Signal prefetcher to stop
|
||||
prefetcher_pipeline.producer_arrive();
|
||||
|
||||
bool launch_dep_grids = false;
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
launch_dep_grids = true;
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class TensorA,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load_MK(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorA const& gA_mkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
|
||||
float overlap_ratio = mainloop_params.overlap_ratio;
|
||||
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_a
|
||||
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Don't wait on dependent grids when loading `A`, because
|
||||
// we assume `A` (weights) are static.
|
||||
|
||||
bool launch_dep_grids = false;
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
launch_dep_grids = true;
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class TensorB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load_NK(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorB const& gB_nkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_b
|
||||
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the prefetched kernel does not touch
|
||||
// unflushed global memory prior to this instruction
|
||||
cutlass::arch::wait_on_dependent_grids();
|
||||
|
||||
// Signal prefetcher to stop
|
||||
prefetcher_pipeline.producer_arrive();
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
class TensorA,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
prefetch_MK(
|
||||
Params const& mainloop_params,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorA const& gA_mkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0;
|
||||
float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio;
|
||||
int prefetch_iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5 * prefetch_ratio);
|
||||
prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages);
|
||||
|
||||
Tensor sA = make_tensor(
|
||||
make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_a
|
||||
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t prefetcher_stage = 0;
|
||||
uint32_t prefetcher_phase = 0;
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) {
|
||||
|
||||
if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) {
|
||||
break;
|
||||
}
|
||||
|
||||
prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages);
|
||||
using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType;
|
||||
BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage);
|
||||
|
||||
int write_stage = 0;
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
++k_tile_iter;
|
||||
|
||||
prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase);
|
||||
}
|
||||
prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <
|
||||
class FrgTensorC
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_read,
|
||||
FrgTensorC& accum,
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
warpgroup_fence_operand(accum);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum);
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_fence_operand(accum);
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accum);
|
||||
|
||||
// UNLOCK smem_pipe_release, done _computing_ on it
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum);
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void
|
||||
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,117 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
float overlap_ratio = 0.5f, prefetch_ratio = 0.5f;
|
||||
int iterations = 1000;
|
||||
int n = 64, m = 1280, k = 8192, l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f);
|
||||
cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "63_hopper_gemm_with_weight_prefetch\n\n"
|
||||
<< " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. \n"
|
||||
<< " For more details please refer to the source file.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --p=<f32> Prefetch ratio\n"
|
||||
<< " --o=<f32> Overlap ratio\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "63_hopper_gemm_with_weight_prefetch" <<
|
||||
" --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
/// Compute effective bandwidth in GB/sec
|
||||
double effective_bandwidth(
|
||||
double runtime_s,
|
||||
size_t bytes_a,
|
||||
size_t bytes_b,
|
||||
size_t bytes_c,
|
||||
size_t bytes_d
|
||||
) const
|
||||
{
|
||||
static double const kBytesPerGiB = double(1ull << 30);
|
||||
|
||||
double bytes_in =
|
||||
(double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A
|
||||
(double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B
|
||||
(beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C
|
||||
double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D
|
||||
|
||||
double gb_total = (bytes_in + bytes_out) / kBytesPerGiB;
|
||||
return gb_total / runtime_s;
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,561 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "../collective/dispatch_policy_extra.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GEMM + Prefetch for the A tensor + (optional) split DMA warps
|
||||
template <
|
||||
class ProblemShape_,
|
||||
class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_,
|
||||
class TileScheduler_
|
||||
>
|
||||
class GemmUniversal<
|
||||
ProblemShape_,
|
||||
CollectiveMainloop_,
|
||||
CollectiveEpilogue_,
|
||||
TileScheduler_,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA> ||
|
||||
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>
|
||||
>
|
||||
>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
|
||||
|
||||
static constexpr bool SplitWarps = cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>;
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
|
||||
"TMA warp-specialized kernel does not support specializing the tile scheduler.");
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler = typename detail::TileSchedulerSelector<
|
||||
TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage {
|
||||
// Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union
|
||||
union TensorStorage {
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16, _1> {
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) PrefetcherPipelineStorage prefetcher;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = 1;
|
||||
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static
|
||||
Params
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
|
||||
// swap M/N
|
||||
get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
}
|
||||
return {
|
||||
args.mode,
|
||||
problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
|
||||
};
|
||||
}
|
||||
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
||||
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static
|
||||
size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static
|
||||
cutlass::Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
auto cluster_shape = ClusterShape{};
|
||||
auto tile_shape = TileShape{};
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
return TileScheduler::get_tiled_cta_shape_mnl(
|
||||
problem_shape_MNKL, tile_shape, cluster_shape);
|
||||
}
|
||||
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
operator()(Params const& params, char* smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
#if defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
# define ENABLE_SM90_KERNEL_LEVEL 1
|
||||
#endif
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if ! defined(ENABLE_SM90_KERNEL_LEVEL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
enum class WarpGroupRole {
|
||||
Producer = 0,
|
||||
Consumer = 1,
|
||||
};
|
||||
// Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK.
|
||||
// Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused.
|
||||
// Both modes use Warp1 to prefetch.
|
||||
enum class ProducerWarpRole {
|
||||
Warp0 = 0,
|
||||
PrefetchMK = 1,
|
||||
Warp2 = 2,
|
||||
UnusedWarp = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
if (warp_group_role == WarpGroupRole::Producer && (
|
||||
producer_warp_role == ProducerWarpRole::Warp0 ||
|
||||
producer_warp_role == ProducerWarpRole::Warp2)) {
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer) {
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
bool should_prefetch = params.mainloop.prefetch_ratio > 0;
|
||||
using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline;
|
||||
typename PrefetcherPipeline::Params prefetcher_pipeline_params;
|
||||
prefetcher_pipeline_params.num_prefetchers = 1;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
||||
prefetcher_pipeline_params.should_prefetch = should_prefetch;
|
||||
prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk;
|
||||
}
|
||||
PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params);
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
|
||||
epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
|
||||
}
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = [&] () {
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
// Non-prefetcher warps arrive and wait,
|
||||
// Prefetcher warp can go ahead without waiting.
|
||||
cute::cluster_arrive_relaxed();
|
||||
if (warp_group_role != WarpGroupRole::Producer ||
|
||||
producer_warp_role != ProducerWarpRole::PrefetchMK) {
|
||||
cute::cluster_wait();
|
||||
}
|
||||
return [] () {};
|
||||
}
|
||||
else {
|
||||
// __syncthreads() but only for non prefetcher warps
|
||||
if (should_prefetch) {
|
||||
|
||||
// Use a named barrier to let the prefetcher warp start loading into the L2
|
||||
// without waiting to sync with all other warps.
|
||||
// All other warps need to sync because the mainloop pipeline init
|
||||
// should be visible to all of them.
|
||||
// Prefetcher has its own barriers, and the only warps it would need to sync
|
||||
// with would be the DMA warps.
|
||||
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
|
||||
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z,
|
||||
/*reserved_named_barriers_*/ 14);
|
||||
// Prefetcher warp doesn't arrive on this barrier.
|
||||
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
|
||||
/*reserved_named_barriers_*/ 15);
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
||||
__syncwarp();
|
||||
prefetcher_arrive_barrier.arrive();
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Producer) {
|
||||
prefetcher_arrive_barrier.arrive_and_wait();
|
||||
cluster_arrive_barrier.arrive_and_wait();
|
||||
}
|
||||
else {
|
||||
prefetcher_arrive_barrier.arrive();
|
||||
cluster_arrive_barrier.arrive_and_wait();
|
||||
}
|
||||
} else {
|
||||
__syncthreads();
|
||||
}
|
||||
return [] () {};
|
||||
}
|
||||
} ();
|
||||
|
||||
// Preconditions
|
||||
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
TiledMma tiled_mma;
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
// Compute m_coord, n_coord, and l_coord with their post-tiled shapes
|
||||
auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Get pipeline iterators and increments from tensor shapes
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
if (producer_warp_role == ProducerWarpRole::Warp0) {
|
||||
if constexpr(SplitWarps) {
|
||||
collective_mainloop.load_NK(
|
||||
params.mainloop,
|
||||
mainloop_pipeline,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gB_nkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
}
|
||||
else {
|
||||
collective_mainloop.load(
|
||||
params.mainloop,
|
||||
mainloop_pipeline,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gA_mkl, gB_nkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
}
|
||||
// Update starting mainloop pipeline state for the pipeline drain
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
// Make sure mainloop consumer has been waited upon before issuing epilogue load
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
|
||||
if (collective_epilogue.is_producer_load_needed()) {
|
||||
// Ensure warp is converged before issuing epilogue loads
|
||||
__syncwarp();
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_producer_state,
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord,
|
||||
tiled_mma,
|
||||
lane_idx,
|
||||
shared_storage.tensors.epilogue
|
||||
);
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
}
|
||||
}
|
||||
else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) {
|
||||
collective_mainloop.load_MK(
|
||||
params.mainloop,
|
||||
mainloop_pipeline,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gA_mkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
// Update starting mainloop pipeline state for the pipeline drain
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
// Make sure mainloop consumer has been waited upon before issuing epilogue load
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
} else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) {
|
||||
collective_mainloop.prefetch_MK(
|
||||
params.mainloop,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gA_mkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
}
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer) {
|
||||
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
accumulators,
|
||||
k_tile_count,
|
||||
warp_group_thread_idx,
|
||||
shared_storage.tensors.mainloop,
|
||||
params.mainloop
|
||||
);
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
collective_mainloop.mma_tail(
|
||||
mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
k_tile_count
|
||||
);
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
||||
collective_epilogue.store(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline,
|
||||
epi_store_pipe_producer_state,
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord,
|
||||
accumulators,
|
||||
tiled_mma,
|
||||
warp_group_thread_idx,
|
||||
shared_storage.tensors.epilogue
|
||||
);
|
||||
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline,
|
||||
epi_store_pipe_producer_state_next
|
||||
);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@ -0,0 +1,161 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
#include "cute/container/array.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// MSVC work-around
|
||||
template <int Stages>
|
||||
struct PrefetcherPipelineSharedStorage {
|
||||
using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
using Barrier = cutlass::arch::ClusterBarrier;
|
||||
|
||||
TransactionBarrier tma_barrier[Stages];
|
||||
Barrier producer_ready_barrier;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction
|
||||
// barrier providing control over the number of concurrent outstanding TMA loads.
|
||||
// There is also an additional cluster barrier which is only used when `prefetch_ratio` is unset.
|
||||
// `prefetch_ratio` determines how many K tiles get loaded, and when unset, the prefetcher checks
|
||||
// whether DMA warps are done waiting on griddepcontrol, and if so, stops issuing more TMA loads.
|
||||
template <int Stages_>
|
||||
class PrefetchPipeline {
|
||||
public :
|
||||
static constexpr uint32_t Stages = Stages_;
|
||||
using SharedStorage = detail::PrefetcherPipelineSharedStorage<Stages>;
|
||||
|
||||
using TransactionBarrier = typename SharedStorage::TransactionBarrier;
|
||||
using Barrier = typename SharedStorage::Barrier;
|
||||
using PrefetcherBarrierType = typename TransactionBarrier::ValueType;
|
||||
|
||||
struct Params {
|
||||
uint32_t transaction_bytes = 0;
|
||||
uint32_t num_prefetchers = 1;
|
||||
bool should_prefetch = false;
|
||||
};
|
||||
|
||||
// Constructor
|
||||
CUTLASS_DEVICE
|
||||
PrefetchPipeline(SharedStorage& storage, Params params)
|
||||
: params_(params)
|
||||
, tma_barrier_ptr_(&storage.tma_barrier[0])
|
||||
, producer_ready_barrier_ptr_(&storage.producer_ready_barrier) {
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
if (params.should_prefetch && lane_predicate) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Stages; ++i) {
|
||||
tma_barrier_ptr_[i].init(params.num_prefetchers);
|
||||
}
|
||||
producer_ready_barrier_ptr_[0].init(1);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void producer_arrive() {
|
||||
if (params_.should_prefetch) {
|
||||
producer_ready_barrier_ptr_[0].arrive();
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool have_producers_arrived() {
|
||||
if (params_.should_prefetch) {
|
||||
uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0);
|
||||
auto barrier_status = static_cast<BarrierStatus>(barrier_status_);
|
||||
if (barrier_status == BarrierStatus::WaitDone) {
|
||||
return true; // exit prefetcher loop
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) {
|
||||
if (params_.should_prefetch) {
|
||||
if (should_wait) {
|
||||
tma_barrier_ptr_[stage].wait(phase ^ 1);
|
||||
}
|
||||
tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void advance_prefetcher_state(uint32_t& stage, uint32_t& phase) {
|
||||
if (params_.should_prefetch) {
|
||||
stage++;
|
||||
if (stage == Stages) {
|
||||
stage = 0;
|
||||
phase ^= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void prefetcher_tail(uint32_t stage, uint32_t phase) {
|
||||
if (params_.should_prefetch) {
|
||||
// Wait on any already-issued loads
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < stage; ++i) {
|
||||
tma_barrier_ptr_[i].wait(phase);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) {
|
||||
return reinterpret_cast<PrefetcherBarrierType*>(&tma_barrier_ptr_[stage]);
|
||||
}
|
||||
|
||||
private :
|
||||
TransactionBarrier* tma_barrier_ptr_ = nullptr;
|
||||
Barrier* producer_ready_barrier_ptr_ = nullptr;
|
||||
Params params_;
|
||||
|
||||
};
|
||||
|
||||
} // end namespace cutlass
|
||||
35
examples/64_ada_fp8_gemm_grouped/CMakeLists.txt
Normal file
35
examples/64_ada_fp8_gemm_grouped/CMakeLists.txt
Normal file
@ -0,0 +1,35 @@
|
||||
|
||||
# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
64_ada_fp8_gemm_grouped
|
||||
ada_fp8_gemm_grouped.cu
|
||||
)
|
||||
1208
examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu
Normal file
1208
examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user